hook函数与CAM可视化
发布日期:2021-05-14 14:42:16 浏览次数:15 分类:精选文章

本文共 2617 字,大约阅读时间需要 8 分钟。

hook函数与CAM可视化

一 hook_fmap_vis.py

介绍 本节将通过hook函数实现特征图的可视化,将详细介绍hook函数在特征图分析中的应用。

代码解析

import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# 正常化参数
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
# 数据预处理
img_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(normMean, normStd)
])
# 读取图片
path_img = "./lena.png"
img_pil = Image.open(path_img).convert('RGB')
# 生成特征图
img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0)
# 预训练模型
alexnet = models.alexnet(pretrained=True)
# 初始化hook函数
fmap_dict = {}
for name, sub_module in alexnet.named_modules():
if isinstance(sub_module, nn.Conv2d):
key_name = str(sub_module.weight.shape)
fmap_dict.setdefault(key_name, list())
def hook_func(m, i, o):
key_name = str(m.weight.shape)
fmap_dict[key_name].append(o)
alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)
# 前向传播
output = alexnet(img_tensor)
# 可视化特征图
for layer_name, fmap_list in fmap_dict.items():
fmap = fmap_list[0]
fmap.transpose_(0, 1)
nrow = int(np.sqrt(fmap.shape[0]))
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

功能说明

  • 读取图片并进行标准化转换
  • 初始化预训练AlexNet模型
  • 通过hook函数注册每个卷积层的特征图输出
  • 打印并可视化不同层的特征图
  • 依赖环境

    • PyTorch
    • TensorBoard
    • PIL库
    • OpenCV

    二 hook_methods.py

    介绍 本节将介绍如何利用hook函数进行张量分析,包括Tensor hooks的应用和注册回溯hook函数的示例。

    代码解析

    import torch
    import torch.nn as nn
    from tools.common_tools import set_seed
    set_seed(1) # 设置随机种子
    # 1.Tensor hook实例
    flag = 0
    # flag = 1 可以经过门控开启
    if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    a_grad = list()
    def grad_hook(grad):
    grad *= 2
    return grad * 3
    handle = w.register_hook(grad_hook)
    y.backward()
    print("w.grad: ", w.grad)
    handle.remove()

    功能说明

  • 定义一个有易 gradient hooks 的函数
  • 创立一个 关键词 seed字符串 的 Tensorboard SummaryWriter 对象
  • 执行网络推理并丢弃反向传播
  • 查看 w_tensor 的梯度变化
  • 代码注释

    • register_hook 方法允许自定义的梯度计算函数替换原始的。
    • SummaryWriter 用于以可视化格式记录实验结果。
    • grad_hook 函数会在每个反向传播步骤调用,修改梯度。

    注意事项

    • 每次运行前需确保iendo update callback 函数已被正确注册。
    • 带有 requires_grad=True 的张量才会生成梯度。你可以根据需要添加更多张量。
    • 建议不要在嵌套层级过多中使用 register hook,可能会导致递归错误。
    上一篇:人工智能与信息社会
    下一篇:Tensorboard使用

    发表评论

    最新留言

    留言是一种美德,欢迎回访!
    [***.207.175.100]2025年04月28日 23时30分33秒