
机器学习11:pytorch训练自定义数据集简单示例
发布日期:2021-05-10 22:30:28
浏览次数:22
分类:精选文章
本文共 3387 字,大约阅读时间需要 11 分钟。
PyTorch 训练自定义数据集简易分类器
环境配置
确保 PyTorch 环境配置正确,安装所需的主库包:
import torchimport torchvisionfrom torchvision import transformsimport torch.nn as nnfrom torch.autograd import Variableimport torch.optim as optimimport matplotlib.pyplot as pltimport numpy as np
数据处理
训练集处理
使用 ImageFolder
加载训练数据夹中的所有图片:
def loadtraindata(): path = "/path/to/train" trainset = torchvision.datasets.ImageFolder(path, transform=transforms.Compose([ transforms.Resize((32, 32)), transforms.CenterCrop(32), transforms.ToTensor() ])) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) return trainloader
测试集处理
类似地处理测试数据集:
def loadtestdata(): path = "/path/to/test" testset = torchvision.datasets.ImageFolder(path, transform=transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor() ])) testloader = torch.utils.data.DataLoader(testset, batch_size=25, shuffle=True, num_workers=2) return testloader
网络结构
定义一个简单的卷积神经网络:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
训练过程
定义训练函数,完整实现训练过程:
def trainandsave(): trainloader = loadtraindata() net = Net() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) criterion = nn.CrossEntropyLoss() for epoch in range(5): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = Variable(data[0]), Variable(data[1]) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 200 == 199: print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}') running_loss = 0.0 print('Finished Training') torch.save(net, 'net.pkl') torch.save(net.state_dict(), 'net_params.pkl')
测试验证
定义测试函数,验证网络性能:
def reload_net(): return torch.load('net.pkl')def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()def test(): testloader = loadtestdata() net = reload_net() dataiter = iter(testloader) images, labels = dataiter.next() grid_img = torchvision.utils.make_grid(images, nrow=5) imshow(grid_img) classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] print(f'GroundTruth: {[" ".join("%s" % classes[labels[j]]) for j in range(25)]}') outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) print(f'Predicted: {[" ".join("%s" % classes[predicted[j]]) for j in range(25)]}')
模型训练总结
通过以上步骤,您已经成功配置并训练了一个使用 PyTorch 实现的手写数字分类器。网络通过卷积层和全连接层完成图像特征提取与分类,输出为 0-9 数字对应的分类结果。模型参数保存在 net.pkl
和 net_params.pkl
文件中,可加载到其他环境中继续使用。
发表评论
最新留言
留言是一种美德,欢迎回访!
[***.207.175.100]2025年05月02日 20时57分19秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
输出对象的值——踩坑
2019-03-15
angular2项目里使用排他思想
2019-03-15
折线图上放面积并隐藏XY轴的线
2019-03-15
zabbix之自动发现
2019-03-15
Experience of tecent interview
2019-03-15
linux 与win共享文件夹
2019-03-15
Linux管理员权限失败su Authentication failure
2019-03-15
EduCoder _Web实训作业---:Web前端开发相关的概念
2019-03-15
python实验-太理
2019-03-15
python实验--太理二
2019-03-15
failed to push some refs to git
2019-03-15