
Tensorflow2.0中的梯度带(GradientTape)、梯度更新以及部分报错的解决方法
发布日期:2021-05-06 23:42:59
浏览次数:40
分类:精选文章
本文共 4402 字,大约阅读时间需要 14 分钟。
TensorFlow GradientTape 组件详解
在TensorFlow中,GradientTape
是一个强大的工具,用于计算和管理梯度。它能够帮助开发者轻松地跟踪变量的变化,并在训练过程中自动计算梯度。GradientTape
的核心功能是记录梯度依赖关系,这对于复杂的机器学习模型(如神经网络)来说尤为重要。
GradientTape 的核心参数
GradientTape
类提供了几个关键参数和方法,主要有以下两个:
persistent(默认值:False
)
- 类型:
bool
- 功能:决定是否在求导之后图表被销毁。如果设置为
True
,则可以多次使用gradient
方法进行二次求导。 - 注意:在大多数情况下,默认值
False
已经足够,因为我们通常只需要一次梯度计算。
watch_accessed_variables(默认值:True
)
- 类型:
bool
- 功能:控制哪些可训练变量会被自动监视,默认情况下为
True
。如果设置为False
,则所有需要求导的变量必须手动使用watch
方法进行监视,否则会抛出错误。
使用 GradientTape 的示例
以下是一个简单的线性回归例子,展示如何使用 GradientTape
来计算和更新梯度:
import tensorflow as tfimport numpy as npTRAIN_STEPS = 20# Prepare training datatrain_X = np.linspace(-1, 1, 100)train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.33 + 10w = tf.Variable(initial_value=1.0)b = tf.Variable(initial_value=1.0)optimizer = tf.keras.optimizers.SGD(0.1)mse = tf.keras.losses.MeanSquaredError()print("Initial w:", w.numpy())print("Initial b:", b.numpy())for i in range(TRAIN_STEPS): print("Epoch:", i) with tf.GradientTape() as tape: logit = w * train_X + b loss = mse(train_Y, logit) gradients = tape.gradient(target=loss, sources=[w, b]) optimizer.apply_gradients(zip(gradients, [w, b])) print("Current w:", w.numpy()) print("Current b:", b.numpy())
注意事项
- 变量必须是可训练变量:在计算梯度时,变量必须使用
tf.Variable
定义,否则会报错。 - 使用
watch
方法:如果需要对特定变量进行梯度更新,可以手动使用tape.watch(variable)
。 - 多次使用
GradientTape
:当使用persistent=True
时,可以多次调用gradient
方法进行二次求导。
常见错误处理
在使用 GradientTape
时,可能会遇到以下错误:
TypeError: zip argument #2 must support iteration
- 出现原因:当只有一个变量需要更新时,忘记将其用列表包装。
- 修复方法:将
sources
参数改为[w]
或类似的列表形式。
TypeError: Cannot iterate over a scalar tensor
- 出现原因:同样是因为单个变量未正确包装。
- 修复方法:确保
sources
是一个列表,包含单个变量。
InvalidArgumentError: var and grad do not have the same shape
- 出现原因:变量和梯度的形状不匹配。
- 修复方法:确保使用正确的变量和梯度对。
高级使用技巧
在复杂模型中,GradientTape
的 persistent
参数非常有用。以下是一个使用 persistent=True
的示例:
def train_step(real_x, real_y): with tf.GradientTape(persistent=True) as tape: generated_y = generator_g(real_x, training=True) generated_x = generator_f(real_y, training=True) disc_real_x = discriminator_x(real_x, training=True) disc_real_y = discriminator_y(real_y, training=True) disc_fake_x = discriminator_x(generated_x, training=True) disc_fake_y = discriminator_y(generated_y, training=True) loss_gen_g = generator_loss(disc_fake_y) loss_gen_f = generator_loss(disc_fake_x) loss_disc_x = discriminator_loss(disc_real_x, disc_fake_y) loss_disc_y = discriminator_loss(disc_real_y, disc_fake_x) cycled_x = generator_f(generated_y, training=True) cycled_y = generator_g(generated_x, training=True) total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y) same_x = generator_f(real_x, training=True) same_y = generator_g(real_y, training=True) total_loss_gen_g = loss_gen_g + total_cycle_loss + identity_loss(real_y, same_y) total_loss_gen_f = loss_gen_f + total_cycle_loss + identity_loss(real_x, same_x) tape.watch(generated_g.trainable_variables()) tape.watch(generated_f.trainable_variables()) tape.watch(discriminator_x.trainable_variables()) tape.watch(discriminator_y.trainable_variables()) grad_gen_g = tape.gradient(total_loss_gen_g, generator_g.trainable_variables()) grad_gen_f = tape.gradient(total_loss_gen_f, generator_f.trainable_variables()) grad_disc_x = tape.gradient(loss_disc_x, discriminator_x.trainable_variables()) grad_disc_y = tape.gradient(loss_disc_y, discriminator_y.trainable_variables()) generator_g_optimizer.apply_gradients(zip(grad_gen_g, generator_g.trainable_variables())) generator_f_optimizer.apply_gradients(zip(grad_gen_f, generator_f.trainable_variables())) discriminator_x_optimizer.apply_gradients(zip(grad_disc_x, discriminator_x.trainable_variables())) discriminator_y_optimizer.apply_gradients(zip(grad_disc_y, discriminator_y.trainable_variables()))
注意事项
- 手动监视变量:如果
watch_accessed_variables
设置为False
,则需要手动使用tape.watch()
方法对需要更新的变量进行监视。 - 图表状态:使用
persistent=True
时,图表状态会被保留,可以多次调用gradient
方法。
总结
GradientTape
是TensorFlow中一个强大的工具,能够帮助开发者轻松地跟踪和计算梯度。在使用时,需要注意变量的定义、梯度的计算方式以及错误处理等关键点。通过合理使用 persistent
和 watch_accessed_variables
参数,可以有效地管理梯度计算和变量更新,提升模型训练效率。
发表评论
最新留言
网站不错 人气很旺了 加油
[***.192.178.218]2025年04月10日 13时05分01秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
惊闻NBC在奥运后放弃使用Silverlight
2019-03-06
IE下尚未实现错误的原因
2019-03-06
创建自己的Docker基础镜像
2019-03-06
HTTP 协议图解
2019-03-06
Python 简明教程 --- 20,Python 类中的属性与方法
2019-03-06
Python 简明教程 --- 21,Python 继承与多态
2019-03-06
KNN 算法-理论篇-如何给电影进行分类
2019-03-06
Spring Cloud第九篇 | 分布式服务跟踪Sleuth
2019-03-06
CODING 敏捷实战系列课第三讲:可视化业务分析
2019-03-06
使用 CODING DevOps 全自动部署 Hexo 到 K8S 集群
2019-03-06
工作动态尽在掌握 - 使用 CODING 度量团队效能
2019-03-06
CODING DevOps 代码质量实战系列最后一课,周四发车
2019-03-06
CODING DevOps 深度解析系列第二课报名倒计时!
2019-03-06
CODING DevOps 线下沙龙回顾二:SDK 测试最佳实践
2019-03-06
翻译:《实用的Python编程》03_01_Script
2019-03-06
数据结构第八节(图(下))
2019-03-06
基础篇:异步编程不会?我教你啊!CompletableFuture
2019-03-06
基于Mustache实现sql拼接
2019-03-06
气球游戏腾讯面试题滑动窗口解法
2019-03-06
POJ 2260 Error Correction 模拟 贪心 简单题
2019-03-06