
使用Pytorch框架实现简单的数据分类(7)
发布日期:2021-05-09 12:07:39
浏览次数:3
分类:技术文章
本文共 2645 字,大约阅读时间需要 8 分钟。
(1)代码中使用的函数简要介绍
torch.normal #张量里面的随机数是从相互独立的正态分布中随机生成的。torch.cat #将两个张量(tensor)拼接在一起
(2)代码
import torchfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as plt# torch.manual_seed(1) # reproducible#制造训练数据n_data = torch.ones(100, 2)x0 = torch.normal(2*n_data, 1) # class0 x data (tensor), shape=(100, 2)y0 = torch.zeros(100) # class0 y data (tensor), shape=(100, 1)x1 = torch.normal(-2*n_data, 1) # class1 x data (tensor), shape=(100, 2)y1 = torch.ones(100) # class1 y data (tensor), shape=(100, 1)x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floatingy = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,) LongTensor = 64-bit integer'''# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensorsx, y = Variable(x), Variable(y)plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')plt.show()'''class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer self.out = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): x = F.relu(self.hidden(x)) # activation function for hidden layer x = self.out(x) return xnet = Net(n_feature=2, n_hidden=10, n_output=2) # define the networkprint(net) # net architectureoptimizer = torch.optim.SGD(net.parameters(), lr=0.02)loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hottedplt.ion() # something about plottingfor t in range(1000): out = net(x) # input x and predict based on x loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted optimizer.zero_grad() # clear gradients for next train loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients if t % 2 == 0: # plot and show learning process plt.cla() prediction = torch.max(out, 1)[1] pred_y = prediction.data.numpy() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn') accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size) plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'}) plt.pause(1)plt.ioff()plt.show()
(3)代码运行结果
注:本文中代码主要参考链接:
了解更多关于《计算机视觉与图形学》相关知识,请关注公众号:
下载我们视频中代码和相关讲义,请在公众号回复:计算机视觉课程资料
转载地址:https://blog.csdn.net/CSS360/article/details/88386890 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
哈哈,博客排版真的漂亮呢~
[***.90.31.176]2023年09月27日 07时56分43秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
c++——代码区、常量区、静态区、堆、栈
2019-03-08
c++——sizeof和strlen的区别
2019-03-08
shell基础02——命令别名与常用快捷键
2019-03-08
go(基础07)——组合取代继承
2019-03-08
操作系统(三) --CPU管理与多进程图像
2019-03-08
操作系统(四) -- 用户级线程与核心级线程(线程的切换)
2019-03-08
go(基础09)——defer
2019-03-08
操作系统(七) -- 死锁
2019-03-08
操作系统(十一)——文件系统(二)
2019-03-08
操作系统(十二)——文件系统(三)
2019-03-08
操作系统(十三)——CPU Cache
2019-03-08
linux基础07——echo命令
2019-03-08
操作系统(十四)——CPU执行原理
2019-03-08
操作系统(十五)——mmap
2019-03-08
C++ Primer Plus 第四章 复合类型(二)
2019-03-08
实现嵌入式sql的方法之proc编程
2019-03-08
Mysql中的中文乱码问题
2019-03-08
QT基础知识
2019-03-08
QT基本知识(第二天)
2019-03-08
QT基础知识-画图和文件(第三天)
2019-03-08