LeNet网络模型的搭建与训练
发布日期:2022-09-10 02:40:19
浏览次数:2
分类:技术文章
本文共 4446 字,大约阅读时间需要 14 分钟。
1.网络结构的搭建
import torch.nn as nnimport torch.nn.functional as F# 继承nn.Module# 计算公式:N=(W-F+2P)/S+1,W为输入图片尺寸W×W,F为Filter大小,P为padding大小,S为stride# F=dilation×(kernel_size-1)+1class LeNet(nn.Module): def __init__(self): super(LeNet,self).__init__() self.conv1=nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1,padding=0) self.pool1=nn.MaxPool2d(kernel_size=2,stride=2) self.conv2=nn.Conv2d(16,32,5) self.pool2=nn.MaxPool2d(2,2) self.fc1=nn.Linear(32*5*5,120) self.fc2=nn.Linear(120,84) self.fc3=nn.Linear(84,10) def forward(self,x): # input(3,32,32) F=1×(5-1)+1,N=(32-5+2×0)/1+1=28,output(16,28,28) x=F.relu(self.conv1(x)) x=self.pool1(x) # output(16,14,14) # N=(14-5)/1+1=10,output(16,10,10) x=F.relu(self.conv2(x)) x=self.pool2(x) # output(32,5,5) x=x.view(-1,32*5*5) # output(32*5*5) x=F.relu(self.fc1(x)) # output(120) x=F.relu(self.fc2(x)) # output(84) x=self.fc3(x) # output(10) return x
2.模型的训练
import torchimport torchvisionimport torch.nn as nnfrom model import LeNetimport torch.optim as optimimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as npdef main(): # 图片处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) ]) # CIFAR10数据集,50000张训练图片 # 第一次下载使用时要将download设置为true才能自动去下载数据集 train_set = torchvision.datasets.CIFAR10( root='./data', train=True, download=False, transform=transform ) train_loader = torch.utils.data.DataLoader( train_set, batch_size=36, shuffle=True, num_workers=0 ) # 10000张验证图片 val_set = torchvision.datasets.CIFAR10( root='./data', train=False, download=False, transform=transform ) val_loader = torch.utils.data.DataLoader( val_set, batch_size=10000, shuffle=False, num_workers=0 ) val_data_iter = iter(val_loader) val_image,val_label = val_data_iter.next() # classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck') # def imshow(img): # img =img/2+0.5 # np_img= img.numpy() # plt.imshow(np.transpose(np_img,(1,2,0))) # plt.show() # # 打印标签 # print(' '.join('%5s'% classes[val_label[j]] for j in range(4))) # # 展示图片 # imshow(torchvision.utils.make_grid(val_image)) net = LeNet() loss_function = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(),lr=0.001) for epoch in range(5): running_loss = 0.0 for step,data in enumerate(train_loader,start=0): # get the inputs:data is a list of[inputs,labels] inputs,labels = data # zero the parameter gradients optimizer.zero_grad() # forward+backward+optimize outputs = net(inputs) loss = loss_function(outputs,labels) loss.backward() optimizer.step() # print statistics running_loss+=loss.item() if step%500 == 499:# print every 500 mini-batches with torch.no_grad(): outputs = net(val_image)#[batch,10] predict_y = torch.max(outputs,dim=1)[1] accuracy = torch.eq(predict_y,val_label).sum().item()/val_label.size(0) print('[%d,%5d] train_loss:%.3f test_accuracy:%.3f'%(epoch+1,step+1,running_loss/500,accuracy)) running_loss = 0.0 print("Finished Training!!!") save_path = './Lenet.pth' torch.save(net.state_dict(),save_path)if __name__ =='__main__': main()
3.模型的预测检验
import torchimport torchvision.transforms as transformsfrom PIL import Imagefrom model import LeNetdef main(): transform = transforms.Compose([ transforms.ToTensor(), transforms.Resize((32,32)), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) ]) classes=('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck') net=LeNet() net.load_state_dict(torch.load('Lenet.pth')) im = Image.open('1.jpg') im = transform(im) # [C,H,W] im = torch.unsqueeze(im,dim=0) # [N,C,H,W] with torch.no_grad(): outputs = net(im) predict = torch.max(outputs,dim=1)[1].numpy() print(classes[int(predict)])if __name__ == '__main__': main()
转载地址:https://blog.csdn.net/YYSTINTERNET/article/details/124911365 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
逛到本站,mark一下
[***.202.152.39]2024年02月29日 20时38分23秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
java crc32 使用_Java CRC32的用法
2019-04-21
java.io.file()_Java File getUsableSpace()方法
2019-04-21
java httpclient 工具_spring整合httpClient工具类
2019-04-21
java监控其他服务器运行状态_windows服务器监控多个tomcat运行状态
2019-04-21
java构造函数有什么用_java构造函数有什么用,怎么用
2019-04-21
mysql 匹配 隔开的_按空格分隔关键字并搜索MySQL数据库
2019-04-21
java factory用法_怎样使用Java实现Factory设计模式
2019-04-21
盾神与砝码称重java_[蓝桥杯][算法提高VIP]盾神与砝码称重
2019-04-21
java输出狗的各类信息_第九章Java输入输出操作
2019-04-21
java notify怎么用_java 如何使用notify()
2019-04-21
java metrics 怎么样,Java metrics
2019-04-21
普朗克公式matlab,用MATLAB实现普朗克函数积分的快捷计算.pdf
2019-04-21
swoolec+%3c?php,PHP+Swoole并发编程的魅力
2019-04-21
php 404配置,phpcms如何配置404
2019-04-21