
PyTorch学习笔记——(5)手动实现线性回归 和 利用pytorch实现线性回归
前向传播:通过模型得到预测值 计算损失:使用均方误差函数计算预测值与真实值之间的误差。 反向传播:计算误差对权重和偏置的梯度。 更新参数:使用优化器更新模型参数。 梯度清零:在每次训练前,必须清零梯度,防止梯度叠加。 导入库:使用 数据准备:生成训练数据和真实标签。 模型定义:定义线性回归模型,使用 损失函数和优化器:选择均方误差和 SGD。 训练过程:包含前向传播、损失计算、反向传播和参数更新。 评估和可视化:测试模型,并绘制预测值和真实值的分布图。
发布日期:2021-05-15 00:34:17
浏览次数:19
分类:精选文章
本文共 1677 字,大约阅读时间需要 5 分钟。
PyTorch 实现线性回归
在线性回归问题中,我们的目标是根据给定的特征数据和标签数据,训练出一个线性模型,使其能够预测目标值。以下将详细介绍如何使用 PyTorch 从头实现一个简单的线性回归模型。
一、准备数据
假设我们有以下关系:( y = 3x + 0.8 )。我们将生成 500 个样本数据,其中特征 x 从均匀分布中随机选取。
import torchimport matplotlib.pyplot as plt# 生成数据x = torch.rand([500, 1]) # 特征矩阵,500个样本,1列y_true = 3 * x + 0.8 # true labels
二、定义模型和优化器
在 PyTorch 中,模型通常使用 nn.Module
来定义。我们定义一个简单的线性网络,接收输入 x,输出预测值。
class Linear(nn.Module): def __init__(self): super(Linear, self).__init__() self.linear = nn.Linear(1, 1) # 输入一维,输出一维 def forward(self, x): out = self.linear(x) return outmodel = Linear()
选择一个合适的优化器,这里使用随机梯度下降(SGD)优化器,学习率设置为 0.01。
criterion = nn.MSELoss() # 误差函数optimizer = optim.SGD(model.parameters(), lr=0.01)
三、训练模型
训练模型的过程包括以下几个步骤:
y_predict
。训练过程如下:
for epoch in range(5000): # 前向传播 y_predict = model(x) # 计算损失 loss = criterion(y_predict, y_true) # 清零梯度 optimizer.zero_grad() # 反向传播 loss.backward() # 更新参数 optimizer.step() # 输出 statistics if (epoch + 1) % 100 == 0: print(f'Epoch: {epoch+1}, Loss: {loss.item():.4f}')
四、评估模型
在模型训练结束后,可以评估模型的预测能力。评估模式下,模型不会更新梯度,这样更合适用于测试。
model.eval()predict = model(x).data.numpy()plt.scatter(x.data.numpy(), y_true.data.numpy(), c='r')plt.plot(x.data.numpy(), predict)plt.show()
五、注意事项
六、代码解释总结
torch
和 matplotlib
进行数据处理和可视化。nn.Module
。通过以上步骤,我们就成功实现了一个简单的线性回归模型,能够预测给定的目标值。可以根据需要调整模型复杂度和优化参数,提高模型性能。
发表评论
最新留言
表示我来过!
[***.240.166.169]2025年05月05日 07时29分06秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
16 python基础-恺撒密码
2019-03-11
06.1 python基础--结构控制
2019-03-11
Frame--Api框架
2019-03-11
idea 在Debug 模式中运行语句中函数的方法
2019-03-11
springboot2.1.1开启druid数据库连接池并开启监控
2019-03-11
《朝花夕拾》金句摘抄(五)
2019-03-11
Boostrap技能点整理之【网格系统】
2019-03-11
新闻发布项目——业务逻辑层(UserService)
2019-03-11
hibernate正向生成数据库表以及配置——hibernate.cfg.xml
2019-03-11
javaWeb服务详解(含源代码,测试通过,注释) ——Emp的Dao层
2019-03-11
java实现人脸识别源码【含测试效果图】——Dao层(IUserDao)
2019-03-11
使用ueditor实现多图片上传案例——前台数据层(Index.jsp)
2019-03-11
ssm(Spring+Spring mvc+mybatis)——saveDept.jsp
2019-03-11
JavaScript操作BOM对象
2019-03-11