机器学习11:pytorch训练自定义数据集简单示例
发布日期:2021-05-10 22:30:28 浏览次数:22 分类:精选文章

本文共 3387 字,大约阅读时间需要 11 分钟。

PyTorch 训练自定义数据集简易分类器

环境配置

确保 PyTorch 环境配置正确,安装所需的主库包:

import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

数据处理

训练集处理

使用 ImageFolder 加载训练数据夹中的所有图片:

def loadtraindata():
path = "/path/to/train"
trainset = torchvision.datasets.ImageFolder(path, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.CenterCrop(32),
transforms.ToTensor()
]))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
return trainloader

测试集处理

类似地处理测试数据集:

def loadtestdata():
path = "/path/to/test"
testset = torchvision.datasets.ImageFolder(path, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=25, shuffle=True, num_workers=2)
return testloader

网络结构

定义一个简单的卷积神经网络:

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

训练过程

定义训练函数,完整实现训练过程:

def trainandsave():
trainloader = loadtraindata()
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = Variable(data[0]), Variable(data[1])
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 200 == 199:
print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}')
running_loss = 0.0
print('Finished Training')
torch.save(net, 'net.pkl')
torch.save(net.state_dict(), 'net_params.pkl')

测试验证

定义测试函数,验证网络性能:

def reload_net():
return torch.load('net.pkl')
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def test():
testloader = loadtestdata()
net = reload_net()
dataiter = iter(testloader)
images, labels = dataiter.next()
grid_img = torchvision.utils.make_grid(images, nrow=5)
imshow(grid_img)
classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
print(f'GroundTruth: {[" ".join("%s" % classes[labels[j]]) for j in range(25)]}')
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)
print(f'Predicted: {[" ".join("%s" % classes[predicted[j]]) for j in range(25)]}')

模型训练总结

通过以上步骤,您已经成功配置并训练了一个使用 PyTorch 实现的手写数字分类器。网络通过卷积层和全连接层完成图像特征提取与分类,输出为 0-9 数字对应的分类结果。模型参数保存在 net.pklnet_params.pkl 文件中,可加载到其他环境中继续使用。

上一篇:机器学习12:pytorch中transforms的22个方式【转载】
下一篇:机器学习10:如何理解随机梯度下降

发表评论

最新留言

留言是一种美德,欢迎回访!
[***.207.175.100]2025年05月02日 20时57分19秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章