Experiment: a cascaded noise estimation, blur estimation, and SR network in PyTorch
Published: 2021-05-10 14:19:07



Training is launched with:

python train_sub.py -opt options/train/train_noise_blur_sr.json

Training can be monitored with TensorBoard:

tensorboard --logdir tb_logger/ --port 6008

The data-preparation code can be found on my GitHub ().
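The dataroot names in the configuration below spell out the degradation order: blur, then bicubic x4 downsampling, then additive noise. The following is only a rough, hypothetical sketch of one such degradation step, with a Gaussian kernel and Gaussian noise standing in for the actual settings used to build DIV2K800_sub_blur_bicLRx4_noiseALL (those live in the repo linked above):

# Hypothetical sketch of the degradation pipeline implied by the dataroot names:
# blur -> bicubic x4 downsample -> additive noise. Kernel and noise settings here
# are placeholders, not the ones actually used for the dataset.
import cv2
import numpy as np

def degrade(img_hr, ksize=15, blur_sigma=2.0, noise_sigma=15.0, scale=4):
    """HR image (uint8, HWC) -> blurred, x4-downsampled, noisy LR image."""
    blurred = cv2.GaussianBlur(img_hr, (ksize, ksize), blur_sigma)
    h, w = blurred.shape[:2]
    lr = cv2.resize(blurred, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC)
    noisy = lr.astype(np.float32) + np.random.normal(0.0, noise_sigma, lr.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)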

 

setting

{  "name": "noiseestimation_blurestimation_SR" //  please remove "debug_" during training  , "tb_logger_dir": "sr_noise_blur"  , "use_tb_logger": true  , "model":"sr_noise_blur"  , "scale": 4  , "crop_scale": 4  , "gpu_ids": [3,4]//  , "init_type": "kaiming"////  , "finetune_type": "basic" //sft | basic  , "datasets": {    "train": {      "name": "DIV2K"      , "mode": "LRMRMATHR"      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"      , "dataroot_MR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_blur_bicLRx4"//the target for the noise estimation      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_blur_bicLRx4_noiseALL"      , "dataroot_MAT": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_estimation"//the target for the blur estimation      , "subset_file": null      , "use_shuffle": true      , "n_workers": 8      , "batch_size": 24 // 16      , "HR_size": 128 // 128 | 192 | 96      , "noise_gt": true//residual for the noise      , "use_flip": true      , "use_rot": true    }  , "val": {      "name": "val_set5_x4_c03s08_mod4",      "mode": "LRHR",      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_blur_bicLRx4_noiseALL"    }  }  , "path": {    "root": "/home/guanwp/Blind_Restoration-master/sr_noise_blur"//    , "pretrain_model_G": null//    , "pretrain_model_sub_noise": null//    , "pretrain_model_sub_blur": null  }  , "network_G": {    "which_model_G": "sr_resnet" // sr_resnet | modulate_sr_resnet//    , "norm_type": "sft"    , "norm_type": null    , "mode": "CNA"    , "nf": 64    , "nb": 16    , "in_nc": 9    , "out_nc": 3//    , "gc": 32    , "group": 1//    , "gate_conv_bias": true  }  , "network_sub": {    "which_model_sub": "noise_subnet" // sr_resnet |noise_subnet//    , "norm_type": "adaptive_conv_res"    , "norm_type": "batch"//    , "norm_type": null    , "mode": "CNA"    , "nf": 64//    , "nb": 16    , "in_nc": 3    , "out_nc": 3    , "group": 1//    , "down_scale": 2  }  , "network_sub2": {    "which_model_sub": "blur_subnet" // sr_resnet | blur_subnet//    , "norm_type": "adaptive_conv_res"    , "norm_type": "batch"//    , "norm_type": null    , "mode": "CNA"    , "nf": 64//    , "nb": 16    , "in_nc": 6    , "out_nc": 3    , "group": 1//    , "down_scale": 2  }  , "train": {//    "lr_G": 1e-3    "lr_G": 1e-4    , "lr_scheme": "MultiStepLR"//    , "lr_steps": [200000, 400000, 600000, 800000]    , "lr_steps": [500000]//    , "lr_steps": [600000]//    , "lr_steps": [1000000]    , "lr_gamma": 0.1//    , "lr_gamma": 0.5    , "pixel_criterion_basic": "l2"    , "pixel_criterion_noise": "l2"    , "pixel_criterion_reg_noise": "tv"    , "pixel_criterion_blur": "l2"    , "pixel_criterion_reg_blur": "tv"    , "pixel_weight_basic": 1.0    , "pixel_weight_noise": 1.0    , "pixel_weight_blur": 1.0    , "val_freq": 1e3    , "manual_seed": 0    , "niter": 1e6//    , "niter": 6e5  }  , "logger": {    "print_freq": 200    , "save_checkpoint_freq": 1e3  }}

 

The .mat files in the data pipeline
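Each .mat file under the dataroot_MAT directory stores the blur-estimation target under the key im_residual (the key the dataset below reads). A quick inspection, with a hypothetical file name:

# quick look at one blur-estimation target; the file name here is hypothetical
from scipy.io import loadmat

mat = loadmat('/home/guanwp/BasicSR_datasets/DIV2K800_sub_estimation/0001_s001.mat')
print(mat['im_residual'].shape)  # H x W x C target, cropped like the LR patch during training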

LRMRMATHR_dataset.py

import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
from scipy.io import loadmat


class LRMRMATHRDataset(data.Dataset):
    '''
    Read LR, MR, MAT and HR image pairs.
    Pairing is ensured by the 'sorted' function, so please check the naming convention.
    '''

    def __init__(self, opt):
        super(LRMRMATHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_MR = None
        self.paths_HR = None
        self.paths_MAT = None
        self.LR_env = None  # environment for lmdb
        self.MR_env = None
        self.HR_env = None
        self.MAT_env = None

        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.MR_env, self.paths_MR = util.get_image_paths(opt['data_type'], opt['dataroot_MR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        self.MAT_env, self.paths_MAT = util.get_image_paths(opt['data_type'], opt['dataroot_MAT'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_MR:
            assert len(self.paths_LR) == len(self.paths_MR), \
                'MR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LR), len(self.paths_MR))

        self.random_scale_list = [1]

    def __getitem__(self, index):
        HR_path, LR_path, MR_path, MAT_path = None, None, None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # # modcrop in the validation / test phase
        # if self.opt['phase'] != 'train':
        #     img_HR = util.modcrop(img_HR, scale)

        LR_path = self.paths_LR[index]
        img_LR = util.read_img(self.LR_env, LR_path)
        MR_path = self.paths_MR[index]
        img_MR = util.read_img(self.MR_env, MR_path)

        # get mat file
        MAT_path = self.paths_MAT[index]
        img_MAT = loadmat(MAT_path)['im_residual']
        # kernel_gt = loadmat(MAT_path)['kernel_gt']
        # img_MAT = np.zeros_like(img_LR)

        if self.opt['noise_gt']:
            img_MR = img_LR - img_MR

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, C = img_LR.shape
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_MR = img_MR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_MAT = img_MAT[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
            # for ind, value in enumerate(kernel_gt):
            #     img_MAT[:, :, ind] = np.tile(value, (LR_size, LR_size))

            # augmentation - flip, rotate
            img_MR, img_MAT, img_LR, img_HR = util.augment([img_MR, img_MAT, img_LR, img_HR],
                                                           self.opt['use_flip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
            img_MR = img_MR[:, :, [2, 1, 0]]
            img_MAT = img_MAT[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        img_MR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MR, (2, 0, 1)))).float()
        img_MAT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MAT, (2, 0, 1)))).float()

        return {'HR': img_HR, 'LR': img_LR, 'MR': img_MR, 'MAT': img_MAT, 'HR_path': HR_path,
                'MR_path': MR_path, 'LR_path': LR_path, 'MAT_path': MAT_path}

    def __len__(self):
        return len(self.paths_HR)
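A minimal usage sketch for this dataset; train_opt here stands for the parsed "datasets/train" block of the JSON above (the BasicSR option parser fills in keys such as 'phase': 'train' and 'data_type': 'img'). With HR_size 128 at scale 4, the LR/MR/MAT patches come out 32x32:

from torch.utils.data import DataLoader

train_set = LRMRMATHRDataset(train_opt)
train_loader = DataLoader(train_set, batch_size=24, shuffle=True, num_workers=8)

batch = next(iter(train_loader))
print(batch['LR'].shape, batch['MR'].shape, batch['MAT'].shape, batch['HR'].shape)
# expected: [24, 3, 32, 32] for LR / MR / MAT and [24, 3, 128, 128] for HR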

LRMRHR_dataset.py

import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRMRHRDataset(data.Dataset):
    '''
    Read LR, MR and HR image pairs.
    Pairing is ensured by the 'sorted' function, so please check the naming convention.
    '''

    def __init__(self, opt):
        super(LRMRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_MR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.MR_env = None
        self.HR_env = None

        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.MR_env, self.paths_MR = util.get_image_paths(opt['data_type'], opt['dataroot_MR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_MR:
            assert len(self.paths_LR) == len(self.paths_MR), \
                'MR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LR), len(self.paths_MR))

        self.random_scale_list = [1]

    def __getitem__(self, index):
        HR_path, LR_path, MR_path = None, None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # modcrop in the validation / test phase
        # if self.opt['phase'] != 'train':
        #     img_HR = util.modcrop(img_HR, scale)
        # change color space if necessary
        if self.opt['color']:
            img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]

        LR_path = self.paths_LR[index]
        img_LR = util.read_img(self.LR_env, LR_path)
        MR_path = self.paths_MR[index]
        img_MR = util.read_img(self.MR_env, MR_path)

        if self.opt['noise_gt']:
            img_MR = img_LR - img_MR

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, C = img_LR.shape
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_MR = img_MR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_MR, img_LR, img_HR = util.augment([img_MR, img_LR, img_HR],
                                                  self.opt['use_flip'], self.opt['use_rot'])

        # channel conversion
        if self.opt['color']:
            # img_HR, img_LR, img_MR = util.channel_convert(C, self.opt['color'], [img_HR, img_LR, img_MR])
            img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]
            img_MR = util.channel_convert(C, self.opt['color'], [img_MR])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
            img_MR = img_MR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        img_MR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MR, (2, 0, 1)))).float()

        return {'HR': img_HR, 'LR': img_LR, 'MR': img_MR, 'HR_path': HR_path, 'MR_path': MR_path, 'LR_path': LR_path}

    def __len__(self):
        return len(self.paths_HR)

 

model

The key part is the design of the model: the outputs of the subnetworks have to be concatenated together and passed down the cascade, as the condensed sketch below shows.
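A condensed view of the forward cascade implemented in optimize_parameters of the full model code that follows (the channel counts match the in_nc settings in the configuration):

fake_noise = subnet_noise(var_L)                           # noise estimate from the 3-channel LR input
input_noise = torch.cat((var_L, fake_noise), 1)            # 3 + 3 = 6 channels -> blur subnet (in_nc 6)
fake_blur = subnet_blur(input_noise)                       # 3-channel blur estimate
input_noise_blur = torch.cat((input_noise, fake_blur), 1)  # 6 + 3 = 9 channels -> SR net (in_nc 9)
fake_H = netG(input_noise_blur)                            # super-resolved output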

import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel
from .modules.loss import TVLoss


class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        finetune_type = opt['finetune_type']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet_noise = networks.define_sub(opt).to(self.device)
        self.subnet_blur = networks.define_sub2(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()
            if finetune_type in ['basic', 'sft_basic', 'sft', 'sub_sft']:
                self.subnet_noise.eval()
                self.subnet_blur.eval()
            else:
                self.subnet_noise.train()
                self.subnet_blur.train()

            # loss on noise
            loss_type_noise = train_opt['pixel_criterion_noise']
            if loss_type_noise == 'l1':
                self.cri_pix_noise = nn.L1Loss().to(self.device)
            elif loss_type_noise == 'l2':
                self.cri_pix_noise = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Noise loss type [{:s}] is not recognized.'.format(loss_type_noise))
            self.l_pix_noise_w = train_opt['pixel_weight_noise']

            loss_reg_noise = train_opt['pixel_criterion_reg_noise']
            if loss_reg_noise == 'tv':
                self.cri_pix_reg_noise = TVLoss(0.00001).to(self.device)

            # loss on blur
            loss_type_blur = train_opt['pixel_criterion_blur']
            if loss_type_blur == 'l1':
                self.cri_pix_blur = nn.L1Loss().to(self.device)
            elif loss_type_blur == 'l2':
                self.cri_pix_blur = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Blur loss type [{:s}] is not recognized.'.format(loss_type_blur))
            self.l_pix_blur_w = train_opt['pixel_weight_blur']

            loss_reg_blur = train_opt['pixel_criterion_reg_blur']
            if loss_reg_blur == 'tv':
                self.cri_pix_reg_blur = TVLoss(0.00001).to(self.device)

            # loss on the SR output
            loss_type_basic = train_opt['pixel_criterion_basic']
            if loss_type_basic == 'l1':
                self.cri_pix_basic = nn.L1Loss().to(self.device)
            elif loss_type_basic == 'l2':
                self.cri_pix_basic = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type_basic))
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            self.optim_params = self.__define_grad_params(finetune_type)
            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_MR=True, need_MAT=True):
        self.var_L = data['LR'].to(self.device)  # LR
        self.real_H = data['HR'].to(self.device)  # HR
        if need_MR:
            self.mid_L = data['MR'].to(self.device)  # MR
        if need_MAT:
            self.real_blur = data['MAT'].to(self.device)

    def __define_grad_params(self, finetune_type=None):
        optim_params = []
        if finetune_type == 'sft':
            for k, v in self.netG.named_parameters():
                v.requires_grad = False
                if k.find('Gate') >= 0:
                    v.requires_grad = True
                    optim_params.append(v)
                    print('we only optimize params: {}'.format(k))
        else:
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
            for k, v in self.subnet_noise.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
            for k, v in self.subnet_blur.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
        return optim_params

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()

        # stage 1: noise estimation on the LR input, with TV regularization
        self.fake_noise = self.subnet_noise(self.var_L)
        l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)
        l_pix_noise = l_pix_noise + self.cri_pix_reg_noise(self.fake_noise)

        # stage 2: blur estimation on LR + noise estimate (6 channels);
        # the blur estimate is scaled by 16 before being compared with the target
        input_noise = torch.cat((self.var_L, self.fake_noise), 1)
        self.fake_blur = self.subnet_blur(input_noise)
        l_pix_blur = self.l_pix_blur_w * self.cri_pix_blur(self.fake_blur * 16, self.real_blur)
        l_pix_blur = l_pix_blur + self.cri_pix_reg_blur(self.fake_blur)

        # stage 3: SR on LR + noise estimate + blur estimate (9 channels)
        input_noise_blur = torch.cat((input_noise, self.fake_blur), 1)
        self.fake_H = self.netG(input_noise_blur)
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)

        # the three losses are summed and back-propagated jointly
        l_pix = l_pix_noise + l_pix_blur + l_pix_basic
        l_pix.backward()
        self.optimizer_G.step()
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        self.subnet_noise.eval()
        self.subnet_blur.eval()
        if self.is_train:
            for v in self.optim_params:
                v.requires_grad = False
        else:
            for k, v in self.netG.named_parameters():
                v.requires_grad = False
            for k, v in self.subnet_noise.named_parameters():
                v.requires_grad = False
            for k, v in self.subnet_blur.named_parameters():
                v.requires_grad = False

        self.fake_noise = self.subnet_noise(self.var_L)
        input_noise = torch.cat((self.var_L, self.fake_noise), 1)
        self.fake_blur = self.subnet_blur(input_noise)
        input_noise_blur = torch.cat((input_noise, self.fake_blur), 1)
        self.fake_H = self.netG(input_noise_blur)

        if self.is_train:
            for v in self.optim_params:
                v.requires_grad = True
        else:
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
            for k, v in self.subnet_noise.named_parameters():
                v.requires_grad = True
            for k, v in self.subnet_blur.named_parameters():
                v.requires_grad = True

        self.netG.train()
        if self.opt['finetune_type'] in ['basic', 'sft_basic', 'sft', 'sub_sft']:
            self.subnet_noise.eval()
            self.subnet_blur.eval()
        else:
            self.subnet_noise.train()
            self.subnet_blur.train()

    # def test(self):
    #     self.netG.eval()
    #     for k, v in self.netG.named_parameters():
    #         v.requires_grad = False
    #     self.fake_H = self.netG(self.var_L)
    #     for k, v in self.netG.named_parameters():
    #         v.requires_grad = True
    #     self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # noise subnet
            s, n = self.get_network_description(self.subnet_noise)
            print('Number of parameters in noise subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- noise subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            # blur subnet
            s, n = self.get_network_description(self.subnet_blur)
            print('Number of parameters in blur subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- blur subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        load_path_sub_noise = self.opt['path']['pretrain_model_sub_noise']
        load_path_sub_blur = self.opt['path']['pretrain_model_sub_blur']
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub_noise is not None:
            print('loading model for noise subnet [{:s}] ...'.format(load_path_sub_noise))
            self.load_network(load_path_sub_noise, self.subnet_noise)
        if load_path_sub_blur is not None:
            print('loading model for blur subnet [{:s}] ...'.format(load_path_sub_blur))
            self.load_network(load_path_sub_blur, self.subnet_blur)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet_noise, 'sub_noise', iter_label)
        self.save_network(self.save_dir, self.subnet_blur, 'sub_blur', iter_label)
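The TVLoss imported from .modules.loss is not listed in this post. A minimal total-variation regularizer consistent with how it is used above (a weight passed to the constructor, then called on the estimated maps) could look like the following sketch; it penalizes differences between neighboring pixels to keep the estimates smooth:

# A minimal TV-loss sketch, assuming the standard squared anisotropic formulation;
# the author's actual .modules.loss implementation may differ in detail.
import torch
import torch.nn as nn

class TVLoss(nn.Module):
    def __init__(self, weight=1e-5):
        super(TVLoss, self).__init__()
        self.weight = weight

    def forward(self, x):
        # sum of squared vertical and horizontal gradients, averaged over the batch
        batch_size = x.size(0)
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
        return self.weight * (h_tv + w_tv) / batch_size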

 

network

As for the network structures, both the blur and noise estimation subnetworks adopt the DnCNN architecture, while the SR network uses SRResNet.

Two subnetworks therefore have to be defined in networks.py, via define_sub (noise) and define_sub2 (blur), alongside the usual define_G:

import functools
import torch
import torch.nn as nn
from torch.nn import init

import models.modules.architecture as arch
import models.modules.sft_arch as sft_arch


####################
# initialize
####################


def weights_init_normal(m, std=0.02):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)
    # elif classname.find('AdaptiveConvResNorm') != -1:
    #     init.constant_(m.weight.data, 0.0)
    #     if m.bias is not None:
    #         m.bias.data.zero_()


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    # scale for 'kaiming', std for 'normal'.
    print('initialization method [{:s}]'.format(init_type))
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))


####################
# define network
####################


# Generator
def define_G(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                             act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'modulate_sr_resnet':
        netG = arch.ModulateSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                     upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                     upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
                                     gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'arcnn':
        netG = arch.ARCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                          norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'srcnn':
        netG = arch.SRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                          norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'denoise_resnet':
        netG = arch.DenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                  upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                  upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
                                  down_scale=opt_net['down_scale'], fea_norm=opt_net['fea_norm'],
                                  upsample_norm=opt_net['upsample_norm'])
    elif which_model == 'modulate_denoise_resnet':
        netG = arch.ModulateDenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                          upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                          upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
                                          gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'noise_subnet':
        netG = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    elif which_model == 'cond_denoise_resnet':
        netG = arch.CondDenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                      upscale=opt_net['scale'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
                                      down_scale=opt_net['down_scale'], num_classes=opt_net['num_classes'],
                                      norm_type=opt_net['norm_type'])
    elif which_model == 'adabn_denoise_resnet':
        netG = arch.AdaptiveDenoiseResNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                          upscale=opt_net['scale'], down_scale=opt_net['down_scale'])
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                            act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['init_type'] is not None:
        init_weights(netG, init_type=opt['init_type'], scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG


def define_sub(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_sub']
    which_model = opt_net['which_model_sub']

    if which_model == 'noise_subnet':
        subnet = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                  norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    else:
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        subnet = nn.DataParallel(subnet)
    return subnet


def define_sub2(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_sub2']
    which_model = opt_net['which_model_sub']

    if which_model == 'blur_subnet':
        subnet = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                  norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    elif which_model == 'denoise_resnet':
        subnet = arch.DenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
                                    upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                    upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
                                    down_scale=opt_net['down_scale'], fea_norm=opt_net['fea_norm'],
                                    upsample_norm=opt_net['upsample_norm'])
    else:
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        subnet = nn.DataParallel(subnet)
    return subnet


# Discriminator
def define_D(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'dis_acd':  # sft-gan, Auxiliary Classifier Discriminator
        netD = sft_arch.ACD_VGG_BN_96()
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                         norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    if gpu_ids:
        netD = nn.DataParallel(netD)
    return netD


def define_F(opt, use_bn=False):
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU.
    if use_bn:
        feature_layer = 49
    else:
        feature_layer = 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)
    # netF = arch.ResNet101FeatureExtractor(use_input_norm=True, device=device)
    if gpu_ids:
        netF = nn.DataParallel(netF)
    netF.eval()  # No need to train
    return netF

The network architectures (models/modules/architecture.py):

import math
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

from . import block as B
from . import spectral_norm as SN
from . import adaptive_norm as AN


####################
# Generator
####################


class SRCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(SRCNN, self).__init__()
        fea_conv = B.conv_block(in_nc, nf, kernel_size=9, norm_type=norm_type, act_type=act_type,
                                mode=mode, ada_ksize=ada_ksize)
        mapping_conv = B.conv_block(nf, nf // 2, kernel_size=1, norm_type=norm_type, act_type=act_type,
                                    mode=mode, ada_ksize=ada_ksize)
        HR_conv = B.conv_block(nf // 2, out_nc, kernel_size=5, norm_type=norm_type, act_type=None,
                               mode=mode, ada_ksize=ada_ksize)
        self.model = B.sequential(fea_conv, mapping_conv, HR_conv)

    def forward(self, x):
        x = self.model(x)
        return x


class ARCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(ARCNN, self).__init__()
        fea_conv = B.conv_block(in_nc, nf, kernel_size=9, norm_type=norm_type, act_type=act_type,
                                mode=mode, ada_ksize=ada_ksize)
        conv1 = B.conv_block(nf, nf // 2, kernel_size=7, norm_type=norm_type, act_type=act_type,
                             mode=mode, ada_ksize=ada_ksize)
        conv2 = B.conv_block(nf // 2, nf // 4, kernel_size=1, norm_type=norm_type, act_type=act_type,
                             mode=mode, ada_ksize=ada_ksize)
        HR_conv = B.conv_block(nf // 4, out_nc, kernel_size=5, norm_type=norm_type, act_type=None,
                               mode=mode, ada_ksize=ada_ksize)
        self.model = B.sequential(fea_conv, conv1, conv2, HR_conv)

    def forward(self, x):
        x = self.model(x)
        return x


class SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                       mode=mode, res_scale=res_scale) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class ModulateSRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateSRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=1)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
        self.LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            self.LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            self.LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.HR_branch = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        out = self.HR_branch(fea + res)
        return out


class DenoiseResNet(nn.Module):
    """
    jingwen's addition
    denoise Resnet
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='batch', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', ada_ksize=None, down_scale=2,
                 fea_norm=None, upsample_norm=None):
        super(DenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=fea_norm, act_type=None, stride=down_scale,
                                ada_ksize=ada_ksize)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                       mode=mode, res_scale=res_scale, ada_ksize=ada_ksize) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode,
                               ada_ksize=ada_ksize)
        # LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode
        #                        , ada_ksize=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, norm_type=upsample_norm, ada_ksize=ada_ksize)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, norm_type=upsample_norm, ada_ksize=ada_ksize)
                         for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=upsample_norm, act_type=act_type, ada_ksize=ada_ksize)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=upsample_norm, act_type=None, ada_ksize=ada_ksize)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class ModulateDenoiseResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateDenoiseResNet, self).__init__()

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        upsampler = upsample_block(nf, nf, act_type=act_type)
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.LR_conv = LR_conv
        self.LR_norm = LR_norm
        self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        out = self.HR_branch(fea + res)
        return out


class NoiseSubNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, norm_type='batch', act_type='relu', mode='CNA'):
        super(NoiseSubNet, self).__init__()
        degration_block = [B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type, act_type=act_type, mode=mode)]
        degration_block.extend([B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=act_type, mode=mode)
                                for _ in range(15)])
        degration_block.append(B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, mode=mode))
        self.degration_block = B.sequential(*degration_block)

    def forward(self, x):
        deg_estimate = self.degration_block(x)
        return deg_estimate


class CondDenoiseResNet(nn.Module):
    """
    jingwen's addition
    denoise Resnet
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, res_scale=1, down_scale=2, num_classes=1, ada_ksize=None,
                 upsample_mode='upconv', act_type='relu', norm_type='cond_adaptive_conv_res'):
        super(CondDenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.CondResNetBlock(nf, nf, nf, num_classes=num_classes, ada_ksize=ada_ksize,
                                           norm_type=norm_type, act_type=act_type) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        if norm_type == 'cond_adaptive_conv_res':
            self.cond_adaptive = AN.CondAdaptiveConvResNorm(nf, num_classes=num_classes)
        elif norm_type == "interp_adaptive_conv_res":
            self.cond_adaptive = AN.InterpAdaptiveResNorm(nf, ada_ksize)
        elif norm_type == "cond_instance":
            self.cond_adaptive = AN.CondInstanceNorm2d(nf, num_classes=num_classes)
        elif norm_type == "cond_transform_res":
            self.cond_adaptive = AN.CondResTransformer(nf, ada_ksize, num_classes=num_classes)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.upsample = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x, y):
        # the first feature extraction
        fea = self.fea_conv(x)
        fea1, _ = self.resnet_blocks((fea, y))
        fea2 = self.LR_conv(fea1)
        fea3 = self.cond_adaptive(fea2, y)
        # res
        out = self.upsample(fea3 + fea)
        return out


class AdaptiveDenoiseResNet(nn.Module):
    """
    jingwen's addition
    adabn
    """

    def __init__(self, in_nc, nf, nb, upscale=1, res_scale=1, down_scale=2):
        super(AdaptiveDenoiseResNet, self).__init__()

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.AdaptiveResNetBlock(nf, nf, nf, res_scale=res_scale) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(nf, affine=True, track_running_stats=True, momentum=0)

    def forward(self, x):
        fea_list = [self.fea_conv(data.unsqueeze_(0)) for data in x]
        fea_resblock_list = self.resnet_blocks(fea_list)
        fea_LR_list = [self.LR_conv(fea) for fea in fea_resblock_list]
        fea_mean, fea_var = B.computing_mean_variance(fea_LR_list)
        batch_norm_dict = self.batch_norm.state_dict()
        batch_norm_dict['running_mean'] = fea_mean
        batch_norm_dict['running_var'] = fea_var
        self.batch_norm.load_state_dict(batch_norm_dict)
        return None
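As a sanity check, the three networks can be instantiated with the channel counts from the configuration and run on a dummy batch. This is only a quick sketch; it assumes the modules above live in models.modules.architecture, as in the BasicSR layout this code follows:

import torch
import models.modules.architecture as arch

# channel counts follow the "setting" section: 3 -> noise, 3+3 -> blur, 3+3+3 -> SR
noise_net = arch.NoiseSubNet(in_nc=3, out_nc=3, nf=64, nb=16, norm_type='batch', mode='CNA')
blur_net = arch.NoiseSubNet(in_nc=6, out_nc=3, nf=64, nb=16, norm_type='batch', mode='CNA')
sr_net = arch.SRResNet(in_nc=9, out_nc=3, nf=64, nb=16, upscale=4, norm_type=None,
                       act_type='relu', mode='CNA', upsample_mode='pixelshuffle')

with torch.no_grad():
    lr = torch.randn(1, 3, 32, 32)                  # a 32x32 LR patch (HR_size 128 / scale 4)
    noise = noise_net(lr)                           # -> (1, 3, 32, 32)
    blur = blur_net(torch.cat((lr, noise), 1))      # -> (1, 3, 32, 32)
    sr = sr_net(torch.cat((lr, noise, blur), 1))    # -> (1, 3, 128, 128)
print(noise.shape, blur.shape, sr.shape)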

 

experiment

 

 
