
查看卷积网络每一层的feature map的代码
发布日期:2021-05-10 14:16:22
浏览次数:24
分类:精选文章
本文共 4201 字,大约阅读时间需要 14 分钟。
import osimport sysimport pdbimport loggingimport timeimport torchimport argparseimport numpy as npimport torch.nn as nnimport torch.nn.functional as Ffrom collections import OrderedDictimport options.options as optionimport utils.util as utilfrom data.util import bgr2ycbcrfrom data import create_dataset, create_dataloaderfrom models import create_modelfrom models.modules import block as Bimport matplotlib.pyplot as plt# optionsopt = 'options/test/test_sr.json'opt = option.parse(opt, is_train=False)util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))opt = option.dict_to_nonedict(opt)# Create test dataset and dataloadertest_loaders = []for phase, dataset_opt in sorted(opt['datasets'].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) print('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) test_loaders.append(test_loader)# Create modelmodel = create_model(opt)# Register hook for featrue map.def save_feature(name): def hook(module, input, output): featuremap[name] = output return hook# Set register hookconv_idx = 0featuremap = OrderedDict()for m in model.netG.module.model.modules(): if m._get_name() == '********': conv_idx += 1 m.register_forward_hook(save_feature('conv_' + str(conv_idx)))#可在module前向传播或反向传播时注册钩子# print(conv_idx)# exit()for test_loader in test_loaders: test_set_name = test_loader.dataset.opt['name'] test_start_time = time.time() dataset_dir = os.path.join(opt['path']['results_root'], test_set_name) util.mkdir(dataset_dir) test_results = OrderedDict() test_results['psnr'] = [] test_results['ssim'] = [] test_results['psnr_y'] = [] test_results['ssim_y'] = [] for data in test_loader: need_HR = False if test_loader.dataset.opt['dataroot_HR'] is None else True model.feed_data(data, need_HR=need_HR) img_path = data['LR_path'][0] img_name = os.path.splitext(os.path.basename(img_path))[0] model.test() # test visuals = model.get_current_visuals(need_HR=need_HR) sr_img = util.tensor2img(visuals['SR']) # uint8gwp=0from mpl_toolkits.axes_grid1 import AxesGridfor k,v in featuremap.items(): # print(v[1].shape) # exit() vals = v[0].squeeze().float().cpu().numpy() print(vals.shape) # exit() fig = plt.figure(figsize=(15,5)) grid = AxesGrid(fig, 111, nrows_ncols=(2, 16), axes_pad=0.05, share_all=True, label_mode="L", cbar_location="right", cbar_mode="single", ) for val, ax in zip(vals,grid): im = ax.imshow(val) grid.cbar_axes[0].colorbar(im) for cax in grid.cbar_axes: cax.toggle_label(True) gwp=gwp+1 # plt.show() fig.savefig(os.path.join(dataset_dir, str(gwp) + '.png'), dpi=400, bbox_inches='tight', transparent=True)# from mpl_toolkits.axes_grid1 import AxesGrid# for k,v in featuremap.items():# vals = []# if isinstance(v, tuple):# for i in range(v[0].shape[1]):# vals.append(v[0].squeeze().float().cpu().numpy()[i]) # hr = F.upsample(v[1], scale_factor=2, mode='nearest')# for i in range(hr.shape[1]):# vals.append(hr.squeeze().float().cpu().numpy()[i])# else:# for i in range(v.shape[1]):# vals = v.squeeze().float().cpu().numpy() # fig = plt.figure(figsize=(15,5))# grid = AxesGrid(fig, 111,# nrows_ncols=(2, 8),# axes_pad=0.05,# share_all=True,# label_mode="L",# cbar_location="right",# cbar_mode="single",# )# for val, ax in zip(vals,grid):# im = ax.imshow(val, vmin=0, vmax=2)# grid.cbar_axes[0].colorbar(im)# for cax in grid.cbar_axes:# cax.toggle_label(True)# plt.show()# # fig.savefig(os.path.join(dataset_dir, k.split('/')[1] + '.png'), dpi=400, bbox_inches='tight', transparent=True)
运行代码后,就可以看到卷积网络每一层layer输出的feature map的形式,进而可以进一步的分析网络
发表评论
最新留言
感谢大佬
[***.8.128.20]2025年04月01日 21时51分58秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
[Unity][EXE]封装打包后怎么Debug错误显示output_log.txt
2021-05-09
使用promise封装wx:requset()
2021-05-09
图文追踪PlusToken资产转移行踪(一): BTC部分有1,203个流入交易所
2021-05-09
stm32h743iit6 cubmex 配置QSPI w25128模式问题
2021-05-09
让nginx支持文件上传的几种模式
2021-05-09
LeetCode 637 二叉树的层平均值-简单
2021-05-09
Redis-day2-五种数据结构类型与数据持久化AOF+RDB
2021-05-10
IOS开发Swift笔记16-错误处理
2021-05-10
Java 天气预报WebService
2021-05-10
redis中RDB和AOF的区别
2021-05-10
《STM32从零开始学习历程》——CAN相关结构体
2021-05-10
原生Javascript实现New方法
2021-05-10
Tomcat中jdk版本与项目版本不一致造成404错误以及Eclipse修改jdk版本
2021-05-10
配置SpringMVC中的视图解析器
2021-05-10
杭电OJ-2034(C)
2021-05-10
this.$router.push不起作用(this指向错误)
2021-05-10
Sublime安装px转rem插件
2021-05-10
IDEA上传Jar
2021-05-10