
使用Pytorch框架实现简单的数据分类(7)
发布日期:2021-05-09 12:07:39
浏览次数:28
分类:精选文章
本文共 2281 字,大约阅读时间需要 7 分钟。
PyTorch基础训练代码与结果分析
代码中使用的函数简要介绍
在PyTorch中,常用到的以下函数和功能:
- Torch.normal:用于生成从相互独立的正态分布中随机生成的张量。
- Torch.cat:将两个张量( tensor )拼接在一起,支持沿不同维度进行拼接。
代码内容
以下是基于PyTorch实现的简单分类模型训练代码:
import torchfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as plt# 数据生成n_data = torch.ones(100, 2) # 类0样本,形状为(100, 2)x0 = torch.normal(2 * n_data, 1) # 类0输入数据,形状为(100, 2)y0 = torch.zeros(100) # 类0标签,形状为(100, 1)x1 = torch.normal(-2 * n_data, 1) # 类1输入数据,形状为(100, 2)y1 = torch.ones(100) # 类1标签,形状为(100, 1)# 数据拼接x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # 形状为(200, 2)y = torch.cat((y0, y1), ).type(torch.LongTensor) # 形状为(200, )# 模型定义class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层 self.out = torch.nn.Linear(n_hidden, n_output) # 输出层 def forward(self, x): x = F.relu(self.hidden(x)) # 激活函数为ReLU x = self.out(x) return xnet = Net(n_feature=2, n_hidden=10, n_output=2) # 定义网络print(net) # 网络结构# 优化器和损失函数optimizer = torch.optim.SGD(net.parameters(), lr=0.02)loss_func = torch.nn.CrossEntropyLoss()# 训练过程plt.ion() # 打开图表for t in range(1000): out = net(x) # 前向传播,预测 loss = loss_func(out, y) # 计算损失 optimizer.zero_grad() # 清除梯度 loss.backward() # 反向传播 optimizer.step() # 应用梯度更新 if t % 2 == 0: plt.cla() # 清理图表 prediction = torch.max(out, 1)[1] # 获取预测标签 pred_y = prediction.data.numpy() # 转换为numpy数组 target_y = y.data.numpy() # 获取真实标签 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn') accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size) plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'}) plt.pause(1)plt.ioff()plt.show()
代码运行结果
通过代码运行可以观察到以下结果:
- 代码生成了两个类别(类0和类1)的数据集,形状分别为(100, 2)和(100, 1),并通过
torch.cat
函数将它们拼接为(200, 2)的输入数据和(200,)的标签数据。 - 定义了一个简单的三层感知机模型,包含隐藏层和输出层。
- 通过
torch.optim.SGD
优化器和CrossEntropyLoss
函数进行训练。 - 代码每隔两次迭代绘制一次训练图表,并显示分类准确率。
训练过程中,图表展示了分类准确率的变化趋势,随着训练次数的增加,准确率逐步提升,最终趋于收敛。
课程推荐
关注我们的公众号,获取更多关于《计算机视觉与图形学》
相关知识和课程资料!
发表评论
最新留言
第一次来,支持一个
[***.219.124.196]2025年04月14日 16时31分03秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
面试题 08.01. 三步问题
2019-03-15
剑指 Offer 11. 旋转数组的最小数字
2019-03-15
剑指 Offer 57. 和为s的两个数字
2019-03-15
git 在本地删除、添加远端的源
2019-03-15
字符串的反转
2019-03-15
docker用法
2019-03-15
word文档注入(追踪word文档)未完
2019-03-15
作为我的第一篇csdn博客吧
2019-03-15
Linux Ubuntu 用命令安装MySql
2019-03-15
java中简单实现栈
2019-03-15
ajax异步提交失败
2019-03-15
查看安卓系统是否卡开了可调试debuggable
2019-03-15
一道简单的访问越界、栈溢出pwn解题记录
2019-03-15
ubuntu18.04.4版本安装docker教程
2019-03-15
嵌入式day17
2019-03-15
Java基础编程
2019-03-15
STS 的共享内存过程(待充分理解)
2019-03-15
CreatePointFont使用方法
2019-03-15