查看卷积网络每一层的feature map的代码
发布日期:2021-05-10 14:16:22 浏览次数:24 分类:精选文章

本文共 4201 字,大约阅读时间需要 14 分钟。

import os
import sys
import pdb
import logging
import time
import argparse
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model
from models.modules import block as B

# Visualize the feature maps produced by selected layers of the generator.
#
# A forward hook is attached to every sub-module of the generator whose class
# name (nn.Module._get_name()) equals TARGET_MODULE_NAME. Each hook stores the
# module's latest output in `featuremap`. The test sets are then run through
# the model. Every forward pass overwrites `featuremap`, so the maps plotted at
# the end come from the LAST image of the LAST dataset. One PNG per hooked
# layer is written into that dataset's results directory.

# Class name of the layers to inspect, e.g. 'Conv2d'.
# NOTE: the value below is a placeholder and must be replaced before running.
TARGET_MODULE_NAME = '********'
# Panel layout of each figure. Channels beyond GRID_ROWS * GRID_COLS are not drawn.
GRID_ROWS, GRID_COLS = 2, 16

# options
opt = 'options/test/test_sr.json'
opt = option.parse(opt, is_train=False)
util.mkdirs((path for key, path in opt['path'].items() if key != 'pretrain_model_G'))
opt = option.dict_to_nonedict(opt)

# Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    print('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
    test_loaders.append(test_loader)

# Create model
model = create_model(opt)

# Layer name -> most recent output of that layer (overwritten on every forward pass).
featuremap = OrderedDict()


def save_feature(name):
    """Return a forward hook that stores the hooked module's output in ``featuremap[name]``."""
    def hook(module, inputs, output):
        featuremap[name] = output
    return hook


# Register hooks for feature maps. Hooks can be attached to a module's forward
# or backward pass; forward hooks are used here to capture each layer's output.
conv_idx = 0
for m in model.netG.module.model.modules():
    if m._get_name() == TARGET_MODULE_NAME:
        conv_idx += 1
        m.register_forward_hook(save_feature('conv_' + str(conv_idx)))

# Run every test image through the model. Only the side effect of the forward
# pass (filling `featuremap` via the hooks) is used here.
for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt['name']
    dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
    util.mkdir(dataset_dir)

    for data in test_loader:
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
        model.feed_data(data, need_HR=need_HR)
        model.test()  # forward pass; fires the hooks

if not featuremap:
    # Either no module matched TARGET_MODULE_NAME or no image was processed.
    print('No feature maps captured; check TARGET_MODULE_NAME ({!r}) and the test datasets.'
          .format(TARGET_MODULE_NAME))
else:
    for idx, (name, output) in enumerate(featuremap.items(), start=1):
        # For a tensor output, [0] drops the batch dimension (first image).
        # For a tuple output, [0] selects the tuple's first element; any other
        # elements are not plotted.
        vals = output[0].detach().squeeze().float().cpu().numpy()
        # squeeze() collapses a single-channel map to (H, W); restore the
        # channel axis so every grid panel receives a 2-D image.
        if vals.ndim == 2:
            vals = vals[np.newaxis]
        print(name, vals.shape)

        fig = plt.figure(figsize=(15, 5))
        grid = AxesGrid(fig, 111,
                        nrows_ncols=(GRID_ROWS, GRID_COLS),
                        axes_pad=0.05,
                        share_all=True,
                        label_mode="L",
                        cbar_location="right",
                        cbar_mode="single",
                        )
        im = None
        for val, ax in zip(vals, grid):
            im = ax.imshow(val)
        if im is not None:
            grid.cbar_axes[0].colorbar(im)
            for cax in grid.cbar_axes:
                cax.toggle_label(True)

        # dataset_dir is the directory of the last dataset processed above,
        # matching the origin of the plotted maps.
        fig.savefig(os.path.join(dataset_dir, str(idx) + '.png'),
                    dpi=400, bbox_inches='tight', transparent=True)
        # Release the figure; otherwise one open figure accumulates per layer.
        plt.close(fig)

运行代码前，需先把代码中的 `'********'` 替换为要查看的层的类名（即 `module._get_name()` 的返回值，例如 `'Conv2d'`）。运行代码后，即可看到卷积网络中每一个目标层输出的 feature map，从而对网络做进一步分析。

 

 

上一篇:学习笔记之——高频谱效率频分复用(SEFDM)
下一篇:学习笔记之——深度强化学习(Deep Reinforcement Learning)

发表评论

最新留言

感谢大佬
[***.8.128.20]2025年04月01日 21时51分58秒