使用Pytorch框架实现简单的数据分类(7)
发布日期:2021-05-09 12:07:39 浏览次数:28 分类:精选文章

本文共 2281 字,大约阅读时间需要 7 分钟。

PyTorch基础训练代码与结果分析

代码中使用的函数简要介绍

在PyTorch中,常用到的以下函数和功能:

  • Torch.normal:用于生成从相互独立的正态分布中随机生成的张量。
  • Torch.cat:将两个张量( tensor )拼接在一起,支持沿不同维度进行拼接。

代码内容

以下是基于PyTorch实现的简单分类模型训练代码:

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
# 数据生成
n_data = torch.ones(100, 2) # 类0样本,形状为(100, 2)
x0 = torch.normal(2 * n_data, 1) # 类0输入数据,形状为(100, 2)
y0 = torch.zeros(100) # 类0标签,形状为(100, 1)
x1 = torch.normal(-2 * n_data, 1) # 类1输入数据,形状为(100, 2)
y1 = torch.ones(100) # 类1标签,形状为(100, 1)
# 数据拼接
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # 形状为(200, 2)
y = torch.cat((y0, y1), ).type(torch.LongTensor) # 形状为(200, )
# 模型定义
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层
self.out = torch.nn.Linear(n_hidden, n_output) # 输出层
def forward(self, x):
x = F.relu(self.hidden(x)) # 激活函数为ReLU
x = self.out(x)
return x
net = Net(n_feature=2, n_hidden=10, n_output=2) # 定义网络
print(net) # 网络结构
# 优化器和损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss()
# 训练过程
plt.ion() # 打开图表
for t in range(1000):
out = net(x) # 前向传播,预测
loss = loss_func(out, y) # 计算损失
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 应用梯度更新
if t % 2 == 0:
plt.cla() # 清理图表
prediction = torch.max(out, 1)[1] # 获取预测标签
pred_y = prediction.data.numpy() # 转换为numpy数组
target_y = y.data.numpy() # 获取真实标签
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1],
c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy,
fontdict={'size': 20, 'color': 'red'})
plt.pause(1)
plt.ioff()
plt.show()

代码运行结果

通过代码运行可以观察到以下结果:

  • 代码生成了两个类别(类0和类1)的数据集,形状分别为(100, 2)和(100, 1),并通过torch.cat函数将它们拼接为(200, 2)的输入数据和(200,)的标签数据。
  • 定义了一个简单的三层感知机模型,包含隐藏层和输出层。
  • 通过torch.optim.SGD优化器和CrossEntropyLoss函数进行训练。
  • 代码每隔两次迭代绘制一次训练图表,并显示分类准确率。

训练过程中,图表展示了分类准确率的变化趋势,随着训练次数的增加,准确率逐步提升,最终趋于收敛。

课程推荐

关注我们的公众号,获取更多关于《计算机视觉与图形学》相关知识和课程资料!

上一篇:Pytorch保存训练好的模型以及参数(8)
下一篇:Pytorch实现线性回归并实时显示拟合过程(6)

发表评论

最新留言

第一次来,支持一个
[***.219.124.196]2025年04月14日 16时31分03秒