LeNet--卷积神经网络开山之作
发布日期:2022-09-10 02:37:16
浏览次数:2
分类:技术文章
本文共 6781 字,大约阅读时间需要 22 分钟。
LeNet的网络结构示意图如下所示:
输入 → 卷积 → 池化 → 卷积 → 池化 → 全连接 → 全连接 → 全连接(输出)
model.py:
import torch.nn as nn
import torch.nn.functional as F


# Expects input of shape (N, 3, 32, 32); PyTorch's Conv2d defaults are stride=1, padding=0.
class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three fully connected layers.

    Args:
        num_classes: Number of logits produced by the last linear layer.
        init_weights: If True, re-initialize every conv/linear weight with
            Kaiming-normal (ReLU gain) and set its bias to zero; if False, keep
            PyTorch's default initialization.
    """

    def __init__(self, num_classes=10, init_weights=True):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0)  # 16 x 28 x 28
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16 x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)  # 32 x 10 x 10
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32 x 5 x 5
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        # Honor num_classes (it used to be ignored, so every model had exactly 10 outputs).
        self.fc3 = nn.Linear(84, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = F.relu(self.conv1(x))  # input (3, 32, 32) -> (16, 28, 28)
        x = self.pool1(x)          # (16, 14, 14)
        x = F.relu(self.conv2(x))  # (32, 10, 10)
        x = self.pool2(x)          # (32, 5, 5)
        # Flatten each sample while keeping the batch dimension. Unlike view(-1, 800),
        # an input that is not 32x32 (e.g. a Resize that differs from training) then
        # fails loudly at fc1 instead of possibly being reshaped into a wrong-sized batch.
        x = x.flatten(start_dim=1)  # (32*5*5,)
        x = F.relu(self.fc1(x))    # (120,)
        x = F.relu(self.fc2(x))    # (84,)
        x = self.fc3(x)            # (num_classes,)
        return x

    def _initialize_weights(self):
        """Apply Kaiming-normal weights and zero biases to every conv and linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)
train.py:
import json
import sys
import torch
import torchvision
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Let matplotlib render CJK characters (SimHei font) and a normal minus sign.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

ROOT_TRAIN = r'E:/cnn/AlexNet/data/train'
ROOT_TEST = r'E:/cnn/AlexNet/data/val'


def main():
    """Train LeNet on an ImageFolder dataset, saving the checkpoint with the best val accuracy.

    Side effects: writes ``class_indices.json`` (index -> class name) and
    ``./LeNet.pth`` (state_dict of the best epoch) in the working directory.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize needs an (h, w) tuple here: a single int only rescales the shorter
        # side and would not produce the 32 x 32 input LeNet expects.
        "val": transforms.Compose([transforms.Resize((32, 32)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    train_dataset = ImageFolder(ROOT_TRAIN, transform=data_transform["train"])
    train_num = len(train_dataset)  # number of training images

    # Invert ImageFolder's {class_name: index} mapping and save it so the predict
    # script can turn a predicted index back into a class name.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    validate_dataset = ImageFolder(ROOT_TEST, transform=data_transform["val"])
    val_num = len(validate_dataset)  # number of validation images
    validate_loader = DataLoader(validate_dataset, batch_size=16, shuffle=False, num_workers=0)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # To preview one validation batch: set shuffle=True on validate_loader (its
    # batch_size decides how many images are shown), add `import numpy as np` and
    # `from torchvision import utils`, then uncomment:
    # test_image, test_label = next(iter(validate_loader))
    # def imshow(img):
    #     img = img / 2 + 0.5  # undo Normalize((0.5,)*3, (0.5,)*3)
    #     plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    #     plt.show()
    # print(' '.join('%5s' % cla_dict[label.item()] for label in test_label))
    # imshow(utils.make_grid(test_image))

    # Size the classifier head to the dataset instead of hard-coding 10 classes.
    num_classes = len(train_dataset.classes)
    net = LeNet(num_classes=num_classes, init_weights=True)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './LeNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())

        # validate
        net.eval()
        correct = 0  # correctly classified validation images this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = outputs.argmax(dim=1)  # index of the highest logit per image
                correct += (predict_y == val_labels.to(device)).sum().item()

        val_accurate = correct / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the best-performing epoch on disk.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
predict.py:
import json
import os

import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet


def main():
    """Classify ``1.jpg`` with the checkpoint written by train.py and print the class name."""
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Prefer the index -> name mapping written by train.py; fall back to the
    # Cat/Dog labels of this two-class example if it is missing.
    classes = ('Cat', 'Dog')
    if os.path.exists('class_indices.json'):
        with open('class_indices.json', encoding='utf-8') as f:
            class_indict = json.load(f)  # JSON turns the integer keys into strings
        classes = tuple(class_indict[str(i)] for i in range(len(class_indict)))

    # train.py saves './LeNet.pth'; the file name must match exactly on
    # case-sensitive filesystems. map_location lets a GPU-saved checkpoint load on CPU.
    state_dict = torch.load('LeNet.pth', map_location='cpu')
    # Size the head from the checkpoint so loading works whatever num_classes it was trained with.
    net = LeNet(num_classes=state_dict['fc3.weight'].shape[0])
    net.load_state_dict(state_dict)
    net.eval()

    with Image.open('1.jpg') as img:
        # Force 3 channels: grayscale/RGBA/palette images would break the 3-channel Normalize.
        im = transform(img.convert('RGB'))  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)
        predict = int(outputs.argmax(dim=1).item())
    # Guard against a head wider than the known class list.
    print(classes[predict] if predict < len(classes) else 'class {}'.format(predict))


if __name__ == '__main__':
    main()
转载地址:https://blog.csdn.net/m0_56247038/article/details/125081619 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
逛到本站,mark一下
[***.202.152.39]2024年03月20日 02时49分24秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
$.ajax() 方法
2019-04-21
JAVA经典算法40例
2019-04-21
HDU- 2063 过山车
2019-04-21
cogs 7. 通信线路
2019-04-21
项目论证
2019-04-21
JavaScript
2019-04-21
iflab隔壁ios组新生面试题
2019-04-21
App两个页面之间的正反传值方法
2019-04-21
数兔子
2019-04-21
python宽度_在Python中对齐框架宽度
2019-04-21
复制关联表mysql_mysql关联表的复制
2019-04-21
java mysql 表关系分析_数据库表的关系
2019-04-21
c语言 变量 函数命名 风格_C语言static变量和函数
2019-04-21
mysql男女字段应该建立索引吗_哪些字段适不适合建索引?
2019-04-21
安装mysql最后一步密码_MySQL安装最后一步无响应解决方法
2019-04-21
mysql modify语句格式_40条MySQL数据库语句格式
2019-04-21
mysql忽略大小写jpa解决_JPA 大小写敏感问题
2019-04-21