LeNet网络模型的搭建与训练
发布日期:2022-09-10 02:40:19 浏览次数:2 分类:技术文章

本文共 4446 字,大约阅读时间需要 14 分钟。

1.网络结构的搭建

import torch.nn as nnimport torch.nn.functional as F# 继承nn.Module# 计算公式:N=(W-F+2P)/S+1,W为输入图片尺寸W×W,F为Filter大小,P为padding大小,S为stride# F=dilation×(kernel_size-1)+1class LeNet(nn.Module):    def __init__(self):        super(LeNet,self).__init__()        self.conv1=nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1,padding=0)        self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)        self.conv2=nn.Conv2d(16,32,5)        self.pool2=nn.MaxPool2d(2,2)        self.fc1=nn.Linear(32*5*5,120)        self.fc2=nn.Linear(120,84)        self.fc3=nn.Linear(84,10)        def forward(self,x):        # input(3,32,32) F=1×(5-1)+1,N=(32-5+2×0)/1+1=28,output(16,28,28)        x=F.relu(self.conv1(x))         x=self.pool1(x) # output(16,14,14)        # N=(14-5)/1+1=10,output(16,10,10)        x=F.relu(self.conv2(x))        x=self.pool2(x) # output(32,5,5)        x=x.view(-1,32*5*5) # output(32*5*5)        x=F.relu(self.fc1(x)) # output(120)        x=F.relu(self.fc2(x)) # output(84)        x=self.fc3(x) # output(10)        return x

2.模型的训练

import torchimport torchvisionimport torch.nn as nnfrom model import LeNetimport torch.optim as optimimport torchvision.transforms as transformsimport matplotlib.pyplot as  pltimport numpy as npdef main():    # 图片处理    transform = transforms.Compose([        transforms.ToTensor(),        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))    ])    # CIFAR10数据集,50000张训练图片    # 第一次下载使用时要将download设置为true才能自动去下载数据集    train_set = torchvision.datasets.CIFAR10(        root='./data',        train=True,        download=False,        transform=transform    )    train_loader = torch.utils.data.DataLoader(        train_set,         batch_size=36,        shuffle=True,         num_workers=0    )    # 10000张验证图片    val_set = torchvision.datasets.CIFAR10(        root='./data',        train=False,        download=False,        transform=transform    )    val_loader = torch.utils.data.DataLoader(        val_set,        batch_size=10000,        shuffle=False,        num_workers=0    )    val_data_iter = iter(val_loader)    val_image,val_label = val_data_iter.next()        # classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')    # def imshow(img):    #     img =img/2+0.5    #     np_img= img.numpy()    #     plt.imshow(np.transpose(np_img,(1,2,0)))    #     plt.show()    # # 打印标签    # print(' '.join('%5s'% classes[val_label[j]] for j in range(4)))    # # 展示图片    # imshow(torchvision.utils.make_grid(val_image))    net = LeNet()    loss_function = nn.CrossEntropyLoss()    optimizer = optim.Adam(net.parameters(),lr=0.001)    for epoch in range(5):        running_loss = 0.0        for step,data in enumerate(train_loader,start=0):            # get the inputs:data is a list of[inputs,labels]            inputs,labels = data            # zero the parameter gradients            optimizer.zero_grad()            # forward+backward+optimize            outputs = net(inputs)             loss = loss_function(outputs,labels)            loss.backward()            optimizer.step()            # print statistics            running_loss+=loss.item()            if step%500 == 499:# print every 500  mini-batches                with torch.no_grad():                    outputs = net(val_image)#[batch,10]                    predict_y = torch.max(outputs,dim=1)[1]                    accuracy = torch.eq(predict_y,val_label).sum().item()/val_label.size(0)                    print('[%d,%5d] train_loss:%.3f test_accuracy:%.3f'%(epoch+1,step+1,running_loss/500,accuracy))                    running_loss = 0.0    print("Finished Training!!!")    save_path = './Lenet.pth'    torch.save(net.state_dict(),save_path)if __name__ =='__main__':    main()

3.模型的预测检验

import torchimport torchvision.transforms as transformsfrom PIL import Imagefrom model import LeNetdef main():    transform = transforms.Compose([        transforms.ToTensor(),        transforms.Resize((32,32)),        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))    ])    classes=('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')    net=LeNet()    net.load_state_dict(torch.load('Lenet.pth'))    im = Image.open('1.jpg')    im = transform(im) # [C,H,W]    im = torch.unsqueeze(im,dim=0) # [N,C,H,W]    with torch.no_grad():        outputs = net(im)        predict = torch.max(outputs,dim=1)[1].numpy()    print(classes[int(predict)])if __name__ == '__main__':    main()

 

 

转载地址:https://blog.csdn.net/YYSTINTERNET/article/details/124911365 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:length,length(),size()
下一篇:LeNet模型及代码详解

发表评论

最新留言

逛到本站,mark一下
[***.202.152.39]2024年02月29日 20时38分23秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

java crc32 使用_Java CRC32的用法 2019-04-21
java读取unicode_java怎么样将unicode解码读取?Java读取本地文件进 2019-04-21
java.io.file()_Java File getUsableSpace()方法 2019-04-21
java httpclient 工具_spring整合httpClient工具类 2019-04-21
java监控其他服务器运行状态_windows服务器监控多个tomcat运行状态 2019-04-21
java给学生按总成绩排名_java - 输入学生成绩,取它们的平均值,然后通过排名等级的学生 - SO中文参考 - www.soinside.com... 2019-04-21
java构造函数有什么用_java构造函数有什么用,怎么用 2019-04-21
mysql 匹配 隔开的_按空格分隔关键字并搜索MySQL数据库 2019-04-21
java factory用法_怎样使用Java实现Factory设计模式 2019-04-21
java窗口内容如何复制_求助Java窗口菜单如何实现复制粘贴剪切等功能(内附源代码)... 2019-04-21
盾神与砝码称重java_[蓝桥杯][算法提高VIP]盾神与砝码称重 2019-04-21
java输出狗的各类信息_第九章Java输入输出操作 2019-04-21
java notify怎么用_java 如何使用notify() 2019-04-21
java加载指定文件为当前文本,java:如何使用bufferedreader读取特定的行 2019-04-21
java metrics 怎么样,Java metrics 2019-04-21
在vscode中php语言配置,Visual Studio Code C / C++ 语言环境配置 2019-04-21
php怎么翻译数据库中的中文,javascript – 如何将翻译后的文本插入数据库php 2019-04-21
普朗克公式matlab,用MATLAB实现普朗克函数积分的快捷计算.pdf 2019-04-21
swoolec+%3c?php,PHP+Swoole并发编程的魅力 2019-04-21
php 404配置,phpcms如何配置404 2019-04-21