
Experiment: a PyTorch cascade of noise estimation, blur estimation, and SR networks
Published: 2021-05-10 14:19:07
Training is launched, and monitored with TensorBoard, as follows:
python train_sub.py -opt options/train/train_noise_blur_sr.json
tensorboard --logdir tb_logger/ --port 6008
The data-processing code can be found on my GitHub ().
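The dataset folder names in the config below suggest the degraded LR inputs were produced by blurring the HR patches, bicubic-downsampling by 4x, and adding noise. A minimal sketch of such a degradation pipeline is shown here; the Gaussian kernel width and the noise-level range are illustrative assumptions, not the exact settings used (those are in the repo's data scripts):

import os
import glob
import cv2
import numpy as np

def degrade(hr_path, out_dir, scale=4):
    # blur -> bicubic x4 downsample -> additive Gaussian noise,
    # mirroring the folder name DIV2K800_sub_blur_bicLRx4_noiseALL
    os.makedirs(out_dir, exist_ok=True)
    img = cv2.imread(hr_path).astype(np.float32)
    blurred = cv2.GaussianBlur(img, (21, 21), sigmaX=1.6)        # kernel width: assumed
    h, w = blurred.shape[:2]
    lr = cv2.resize(blurred, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC)
    sigma = np.random.uniform(0, 50)                             # "noiseALL": assumed range
    noisy = np.clip(lr + np.random.normal(0, sigma, lr.shape), 0, 255)
    cv2.imwrite(os.path.join(out_dir, os.path.basename(hr_path)), noisy.astype(np.uint8))

for p in glob.glob('DIV2K800_sub/*.png'):
    degrade(p, 'DIV2K800_sub_blur_bicLRx4_noiseALL')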
setting
{ "name": "noiseestimation_blurestimation_SR" // please remove "debug_" during training , "tb_logger_dir": "sr_noise_blur" , "use_tb_logger": true , "model":"sr_noise_blur" , "scale": 4 , "crop_scale": 4 , "gpu_ids": [3,4]// , "init_type": "kaiming"//// , "finetune_type": "basic" //sft | basic , "datasets": { "train": { "name": "DIV2K" , "mode": "LRMRMATHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" , "dataroot_MR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_blur_bicLRx4"//the target for the noise estimation , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_blur_bicLRx4_noiseALL" , "dataroot_MAT": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_estimation"//the target for the blur estimation , "subset_file": null , "use_shuffle": true , "n_workers": 8 , "batch_size": 24 // 16 , "HR_size": 128 // 128 | 192 | 96 , "noise_gt": true//residual for the noise , "use_flip": true , "use_rot": true } , "val": { "name": "val_set5_x4_c03s08_mod4", "mode": "LRHR", "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5", "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_blur_bicLRx4_noiseALL" } } , "path": { "root": "/home/guanwp/Blind_Restoration-master/sr_noise_blur"// , "pretrain_model_G": null// , "pretrain_model_sub_noise": null// , "pretrain_model_sub_blur": null } , "network_G": { "which_model_G": "sr_resnet" // sr_resnet | modulate_sr_resnet// , "norm_type": "sft" , "norm_type": null , "mode": "CNA" , "nf": 64 , "nb": 16 , "in_nc": 9 , "out_nc": 3// , "gc": 32 , "group": 1// , "gate_conv_bias": true } , "network_sub": { "which_model_sub": "noise_subnet" // sr_resnet |noise_subnet// , "norm_type": "adaptive_conv_res" , "norm_type": "batch"// , "norm_type": null , "mode": "CNA" , "nf": 64// , "nb": 16 , "in_nc": 3 , "out_nc": 3 , "group": 1// , "down_scale": 2 } , "network_sub2": { "which_model_sub": "blur_subnet" // sr_resnet | blur_subnet// , "norm_type": "adaptive_conv_res" , "norm_type": "batch"// , "norm_type": null , "mode": "CNA" , "nf": 64// , "nb": 16 , "in_nc": 6 , "out_nc": 3 , "group": 1// , "down_scale": 2 } , "train": {// "lr_G": 1e-3 "lr_G": 1e-4 , "lr_scheme": "MultiStepLR"// , "lr_steps": [200000, 400000, 600000, 800000] , "lr_steps": [500000]// , "lr_steps": [600000]// , "lr_steps": [1000000] , "lr_gamma": 0.1// , "lr_gamma": 0.5 , "pixel_criterion_basic": "l2" , "pixel_criterion_noise": "l2" , "pixel_criterion_reg_noise": "tv" , "pixel_criterion_blur": "l2" , "pixel_criterion_reg_blur": "tv" , "pixel_weight_basic": 1.0 , "pixel_weight_noise": 1.0 , "pixel_weight_blur": 1.0 , "val_freq": 1e3 , "manual_seed": 0 , "niter": 1e6// , "niter": 6e5 } , "logger": { "print_freq": 200 , "save_checkpoint_freq": 1e3 }}
The .mat files used in data processing
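The dataset class below loads a per-pixel map stored under the key 'im_residual' from each .mat file and treats it as an H x W x 3 array aligned with the LR image (it is cropped and flipped together with the LR patch). A minimal sketch of writing and reading such a file with scipy; the array contents here are a placeholder assumption, since the actual estimation maps are generated by the scripts in the GitHub repo referenced above:

import numpy as np
from scipy.io import savemat, loadmat

# a hypothetical per-pixel blur-estimation target, same spatial size as the LR image
im_residual = np.zeros((128, 128, 3), dtype=np.float32)  # placeholder values
savemat('0001.mat', {'im_residual': im_residual})

# this mirrors what LRMRMATHRDataset does in __getitem__
target = loadmat('0001.mat')['im_residual']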
LRMRMATHR_dataset.py
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
from scipy.io import loadmat


class LRMRMATHRDataset(data.Dataset):
    '''
    Read LR, MR and HR image pair.
    The pair is ensured by 'sorted' function, so please check the name convention.
    '''

    def __init__(self, opt):
        super(LRMRMATHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_MR = None
        self.paths_HR = None
        self.paths_MAT = None
        self.LR_env = None  # environment for lmdb
        self.MR_env = None
        self.HR_env = None
        self.MAT_env = None

        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.MR_env, self.paths_MR = util.get_image_paths(opt['data_type'], opt['dataroot_MR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        self.MAT_env, self.paths_MAT = util.get_image_paths(opt['data_type'], opt['dataroot_MAT'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_MR:
            assert len(self.paths_LR) == len(self.paths_MR), \
                'MR and LR datasets have different number of images - {}, {}.'.format(
                len(self.paths_LR), len(self.paths_MR))

        self.random_scale_list = [1]

    def __getitem__(self, index):
        HR_path, LR_path, MR_path, MAT_path = None, None, None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # # modcrop in the validation / test phase
        # if self.opt['phase'] != 'train':
        #     img_HR = util.modcrop(img_HR, scale)

        LR_path = self.paths_LR[index]
        img_LR = util.read_img(self.LR_env, LR_path)
        MR_path = self.paths_MR[index]
        img_MR = util.read_img(self.MR_env, MR_path)

        # get mat file
        MAT_path = self.paths_MAT[index]
        img_MAT = loadmat(MAT_path)['im_residual']
        # kernel_gt = loadmat(MAT_path)['kernel_gt']
        # img_MAT = np.zeros_like(img_LR)

        if self.opt['noise_gt']:
            img_MR = img_LR - img_MR

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, C = img_LR.shape
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_MR = img_MR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_MAT = img_MAT[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
            # for ind, value in enumerate(kernel_gt):
            #     img_MAT[:, :, ind] = np.tile(value, (LR_size, LR_size))

            # augmentation - flip, rotate
            img_MR, img_MAT, img_LR, img_HR = util.augment([img_MR, img_MAT, img_LR, img_HR],
                                                           self.opt['use_flip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
            img_MR = img_MR[:, :, [2, 1, 0]]
            img_MAT = img_MAT[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        img_MR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MR, (2, 0, 1)))).float()
        img_MAT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MAT, (2, 0, 1)))).float()

        return {'HR': img_HR, 'LR': img_LR, 'MR': img_MR, 'MAT': img_MAT,
                'HR_path': HR_path, 'MR_path': MR_path, 'LR_path': LR_path, 'MAT_path': MAT_path}

    def __len__(self):
        return len(self.paths_HR)
LRMRHR_dataset.py
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRMRHRDataset(data.Dataset):
    '''
    Read LR, MR and HR image pair.
    The pair is ensured by 'sorted' function, so please check the name convention.
    '''

    def __init__(self, opt):
        super(LRMRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_MR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.MR_env = None
        self.HR_env = None

        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.MR_env, self.paths_MR = util.get_image_paths(opt['data_type'], opt['dataroot_MR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_MR:
            assert len(self.paths_LR) == len(self.paths_MR), \
                'MR and LR datasets have different number of images - {}, {}.'.format(
                len(self.paths_LR), len(self.paths_MR))

        self.random_scale_list = [1]

    def __getitem__(self, index):
        HR_path, LR_path, MR_path = None, None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # modcrop in the validation / test phase
        # if self.opt['phase'] != 'train':
        #     img_HR = util.modcrop(img_HR, scale)
        # change color space if necessary
        if self.opt['color']:
            img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]

        LR_path = self.paths_LR[index]
        img_LR = util.read_img(self.LR_env, LR_path)
        MR_path = self.paths_MR[index]
        img_MR = util.read_img(self.MR_env, MR_path)

        if self.opt['noise_gt']:
            img_MR = img_LR - img_MR

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, C = img_LR.shape
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_MR = img_MR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_MR, img_LR, img_HR = util.augment([img_MR, img_LR, img_HR],
                                                  self.opt['use_flip'], self.opt['use_rot'])

            # channel conversion
            if self.opt['color']:
                # img_HR, img_LR, img_MR = util.channel_convert(C, self.opt['color'], [img_HR, img_LR, img_MR])
                img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]
                img_MR = util.channel_convert(C, self.opt['color'], [img_MR])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
            img_MR = img_MR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        img_MR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_MR, (2, 0, 1)))).float()

        return {'HR': img_HR, 'LR': img_LR, 'MR': img_MR,
                'HR_path': HR_path, 'MR_path': MR_path, 'LR_path': LR_path}

    def __len__(self):
        return len(self.paths_HR)
model
The key part is the design of the model: the output of each estimation network must be concatenated with its input and passed on to the next stage, as the sketch below shows.
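Concretely, with 3-channel images the channel counts match the in_nc values in the config: the noise subnet sees the 3-channel LR input, the blur subnet sees LR plus the noise estimate (6 channels), and the SR network sees all three stacked (9 channels). A minimal runnable sketch of the cascade forward pass, using simple conv stand-ins for the three networks (the real ones are built by define_G / define_sub / define_sub2 below):

import torch
import torch.nn as nn

# stand-ins for the three networks; only the channel counts matter here
subnet_noise = nn.Conv2d(3, 3, 3, padding=1)                                   # in_nc=3
subnet_blur = nn.Conv2d(6, 3, 3, padding=1)                                    # in_nc=6
netG = nn.Sequential(nn.Conv2d(9, 3 * 16, 3, padding=1), nn.PixelShuffle(4))   # in_nc=9, x4 up

lr = torch.randn(1, 3, 32, 32)
noise_est = subnet_noise(lr)
blur_est = subnet_blur(torch.cat((lr, noise_est), 1))
sr = netG(torch.cat((lr, noise_est, blur_est), 1))
print(sr.shape)  # torch.Size([1, 3, 128, 128])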
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel
from .modules.loss import TVLoss


class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        finetune_type = opt['finetune_type']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet_noise = networks.define_sub(opt).to(self.device)
        self.subnet_blur = networks.define_sub2(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()
            if finetune_type in ['basic', 'sft_basic', 'sft', 'sub_sft']:
                self.subnet_noise.eval()
                self.subnet_blur.eval()
            else:
                self.subnet_noise.train()
                self.subnet_blur.train()

            # loss on noise
            loss_type_noise = train_opt['pixel_criterion_noise']
            if loss_type_noise == 'l1':
                self.cri_pix_noise = nn.L1Loss().to(self.device)
            elif loss_type_noise == 'l2':
                self.cri_pix_noise = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Noise loss type [{:s}] is not recognized.'.format(loss_type_noise))
            self.l_pix_noise_w = train_opt['pixel_weight_noise']

            loss_reg_noise = train_opt['pixel_criterion_reg_noise']
            if loss_reg_noise == 'tv':
                self.cri_pix_reg_noise = TVLoss(0.00001).to(self.device)

            # loss on blur
            loss_type_blur = train_opt['pixel_criterion_blur']
            if loss_type_blur == 'l1':
                self.cri_pix_blur = nn.L1Loss().to(self.device)
            elif loss_type_blur == 'l2':
                self.cri_pix_blur = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Blur loss type [{:s}] is not recognized.'.format(loss_type_blur))
            self.l_pix_blur_w = train_opt['pixel_weight_blur']

            loss_reg_blur = train_opt['pixel_criterion_reg_blur']
            if loss_reg_blur == 'tv':
                self.cri_pix_reg_blur = TVLoss(0.00001).to(self.device)

            # loss on the SR output
            loss_type_basic = train_opt['pixel_criterion_basic']
            if loss_type_basic == 'l1':
                self.cri_pix_basic = nn.L1Loss().to(self.device)
            elif loss_type_basic == 'l2':
                self.cri_pix_basic = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type_basic))
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            self.optim_params = self.__define_grad_params(finetune_type)
            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_MR=True, need_MAT=True):
        self.var_L = data['LR'].to(self.device)   # LR
        self.real_H = data['HR'].to(self.device)  # HR
        if need_MR:
            self.mid_L = data['MR'].to(self.device)  # MR, the noise-residual target
        if need_MAT:
            self.real_blur = data['MAT'].to(self.device)  # blur-estimation target

    def __define_grad_params(self, finetune_type=None):
        optim_params = []
        if finetune_type == 'sft':
            for k, v in self.netG.named_parameters():
                v.requires_grad = False
                if k.find('Gate') >= 0:
                    v.requires_grad = True
                    optim_params.append(v)
                    print('we only optimize params: {}'.format(k))
        else:
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
            for k, v in self.subnet_noise.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
            for k, v in self.subnet_blur.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
        return optim_params

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()

        # stage 1: noise estimation from the LR input
        self.fake_noise = self.subnet_noise(self.var_L)
        l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)
        l_pix_noise = l_pix_noise + self.cri_pix_reg_noise(self.fake_noise)

        # stage 2: blur estimation, conditioned on the noise estimate
        input_noise = torch.cat((self.var_L, self.fake_noise), 1)
        self.fake_blur = self.subnet_blur(input_noise)
        l_pix_blur = self.l_pix_blur_w * self.cri_pix_blur(self.fake_blur * 16, self.real_blur)
        l_pix_blur = l_pix_blur + self.cri_pix_reg_blur(self.fake_blur)

        # stage 3: SR, conditioned on both estimates
        input_noise_blur = torch.cat((input_noise, self.fake_blur), 1)
        self.fake_H = self.netG(input_noise_blur)
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)

        l_pix = l_pix_noise + l_pix_blur + l_pix_basic
        l_pix.backward()
        self.optimizer_G.step()

        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        self.subnet_noise.eval()
        self.subnet_blur.eval()
        if self.is_train:
            for v in self.optim_params:
                v.requires_grad = False
        else:
            for k, v in self.netG.named_parameters():
                v.requires_grad = False
            for k, v in self.subnet_noise.named_parameters():
                v.requires_grad = False
            for k, v in self.subnet_blur.named_parameters():
                v.requires_grad = False

        self.fake_noise = self.subnet_noise(self.var_L)
        input_noise = torch.cat((self.var_L, self.fake_noise), 1)
        self.fake_blur = self.subnet_blur(input_noise)
        input_noise_blur = torch.cat((input_noise, self.fake_blur), 1)
        self.fake_H = self.netG(input_noise_blur)

        if self.is_train:
            for v in self.optim_params:
                v.requires_grad = True
        else:
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
            for k, v in self.subnet_noise.named_parameters():
                v.requires_grad = True
            for k, v in self.subnet_blur.named_parameters():
                v.requires_grad = True
        self.netG.train()
        if self.opt['finetune_type'] in ['basic', 'sft_basic', 'sft', 'sub_sft']:
            self.subnet_noise.eval()
            self.subnet_blur.eval()
        else:
            self.subnet_noise.train()
            self.subnet_blur.train()  # restore train mode, mirroring __init__

    # def test(self):
    #     self.netG.eval()
    #     for k, v in self.netG.named_parameters():
    #         v.requires_grad = False
    #     self.fake_H = self.netG(self.var_L)
    #     for k, v in self.netG.named_parameters():
    #         v.requires_grad = True
    #     self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # noise subnet
            s, n = self.get_network_description(self.subnet_noise)
            print('Number of parameters in noise subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- noise subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            # blur subnet
            s, n = self.get_network_description(self.subnet_blur)
            print('Number of parameters in blur subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- blur subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        load_path_sub_noise = self.opt['path']['pretrain_model_sub_noise']
        load_path_sub_blur = self.opt['path']['pretrain_model_sub_blur']
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub_noise is not None:
            print('loading model for noise subnet [{:s}] ...'.format(load_path_sub_noise))
            self.load_network(load_path_sub_noise, self.subnet_noise)
        if load_path_sub_blur is not None:
            print('loading model for blur subnet [{:s}] ...'.format(load_path_sub_blur))
            self.load_network(load_path_sub_blur, self.subnet_blur)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet_noise, 'sub_noise', iter_label)
        self.save_network(self.save_dir, self.subnet_blur, 'sub_blur', iter_label)
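TVLoss is imported from modules/loss.py, which this post does not reproduce. For reference, here is a minimal sketch of a typical total-variation regularizer with the 1e-5 weight passed in the constructor above; this is an assumed standard implementation, not necessarily the repo's exact code:

import torch
import torch.nn as nn

class TVLoss(nn.Module):
    """Total-variation regularizer: penalizes squared differences between
    neighboring pixels so the estimated noise/blur maps stay smooth."""
    def __init__(self, weight=1e-5):
        super(TVLoss, self).__init__()
        self.weight = weight

    def forward(self, x):
        # x: (N, C, H, W); sum of vertical and horizontal gradients, averaged over the batch
        diff_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        diff_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
        return self.weight * (diff_h + diff_w) / x.size(0)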
network
As for the network structures: both the blur and noise estimation subnetworks adopt the DnCNN architecture, while the SR network uses SRResNet.
Two subnetworks need to be defined in networks.py:
import functools
import torch
import torch.nn as nn
from torch.nn import init

import models.modules.architecture as arch
import models.modules.sft_arch as sft_arch

####################
# initialize
####################


def weights_init_normal(m, std=0.02):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)
    # elif classname.find('AdaptiveConvResNorm') != -1:
    #     init.constant_(m.weight.data, 0.0)
    #     if m.bias is not None:
    #         m.bias.data.zero_()


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    # scale for 'kaiming', std for 'normal'.
    print('initialization method [{:s}]'.format(init_type))
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))


####################
# define network
####################

# Generator
def define_G(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'modulate_sr_resnet':
        netG = arch.ModulateSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
            nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'arcnn':
        netG = arch.ARCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'srcnn':
        netG = arch.SRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'denoise_resnet':
        netG = arch.DenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            mode=opt_net['mode'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
            down_scale=opt_net['down_scale'], fea_norm=opt_net['fea_norm'],
            upsample_norm=opt_net['upsample_norm'])
    elif which_model == 'modulate_denoise_resnet':
        netG = arch.ModulateDenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
            nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'noise_subnet':
        netG = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    elif which_model == 'cond_denoise_resnet':
        netG = arch.CondDenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
            nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'],
            upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
            down_scale=opt_net['down_scale'], num_classes=opt_net['num_classes'],
            norm_type=opt_net['norm_type'])
    elif which_model == 'adabn_denoise_resnet':
        netG = arch.AdaptiveDenoiseResNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], down_scale=opt_net['down_scale'])
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'], act_type='leakyrelu', mode=opt_net['mode'],
            upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['init_type'] is not None:
        init_weights(netG, init_type=opt['init_type'], scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG


def define_sub(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_sub']
    which_model = opt_net['which_model_sub']

    if which_model == 'noise_subnet':
        subnet = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    else:
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        subnet = nn.DataParallel(subnet)
    return subnet


def define_sub2(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_sub2']
    which_model = opt_net['which_model_sub']

    if which_model == 'blur_subnet':
        subnet = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    elif which_model == 'denoise_resnet':
        subnet = arch.DenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            mode=opt_net['mode'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'],
            down_scale=opt_net['down_scale'], fea_norm=opt_net['fea_norm'],
            upsample_norm=opt_net['upsample_norm'])
    else:
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        subnet = nn.DataParallel(subnet)
    return subnet


# Discriminator
def define_D(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'dis_acd':  # sft-gan, Auxiliary Classifier Discriminator
        netD = sft_arch.ACD_VGG_BN_96()
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    if gpu_ids:
        netD = nn.DataParallel(netD)
    return netD


def define_F(opt, use_bn=False):
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU.
    if use_bn:
        feature_layer = 49
    else:
        feature_layer = 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
        use_input_norm=True, device=device)
    # netF = arch.ResNet101FeatureExtractor(use_input_norm=True, device=device)
    if gpu_ids:
        netF = nn.DataParallel(netF)
    netF.eval()  # No need to train
    return netF
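For reference, a rough sketch of the option structure these factory functions expect, hand-built from the values in the json above. In the real pipeline the json is parsed by the repo's option loader (BasicSR wraps it so that missing keys read as None), so this dict is only illustrative:

# minimal opt dict sufficient to build the two subnets and the SR net
opt = {
    'gpu_ids': [],   # CPU here, just to construct the modules
    'init_type': None,
    'network_G': {'which_model_G': 'sr_resnet', 'norm_type': None, 'mode': 'CNA',
                  'nf': 64, 'nb': 16, 'in_nc': 9, 'out_nc': 3, 'scale': 4},
    'network_sub': {'which_model_sub': 'noise_subnet', 'norm_type': 'batch', 'mode': 'CNA',
                    'nf': 64, 'nb': 16, 'in_nc': 3, 'out_nc': 3},
    'network_sub2': {'which_model_sub': 'blur_subnet', 'norm_type': 'batch', 'mode': 'CNA',
                     'nf': 64, 'nb': 16, 'in_nc': 6, 'out_nc': 3},
}
netG = define_G(opt)             # SRResNet, 9 input channels
subnet_noise = define_sub(opt)   # DnCNN-style noise estimator
subnet_blur = define_sub2(opt)   # DnCNN-style blur estimator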
Network architectures
import math
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

from . import block as B
from . import spectral_norm as SN
from . import adaptive_norm as AN

####################
# Generator
####################


class SRCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(SRCNN, self).__init__()
        fea_conv = B.conv_block(in_nc, nf, kernel_size=9, norm_type=norm_type, act_type=act_type,
                                mode=mode, ada_ksize=ada_ksize)
        mapping_conv = B.conv_block(nf, nf // 2, kernel_size=1, norm_type=norm_type,
                                    act_type=act_type, mode=mode, ada_ksize=ada_ksize)
        HR_conv = B.conv_block(nf // 2, out_nc, kernel_size=5, norm_type=norm_type,
                               act_type=None, mode=mode, ada_ksize=ada_ksize)
        self.model = B.sequential(fea_conv, mapping_conv, HR_conv)

    def forward(self, x):
        x = self.model(x)
        return x


class ARCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(ARCNN, self).__init__()
        fea_conv = B.conv_block(in_nc, nf, kernel_size=9, norm_type=norm_type, act_type=act_type,
                                mode=mode, ada_ksize=ada_ksize)
        conv1 = B.conv_block(nf, nf // 2, kernel_size=7, norm_type=norm_type, act_type=act_type,
                             mode=mode, ada_ksize=ada_ksize)
        conv2 = B.conv_block(nf // 2, nf // 4, kernel_size=1, norm_type=norm_type, act_type=act_type,
                             mode=mode, ada_ksize=ada_ksize)
        HR_conv = B.conv_block(nf // 4, out_nc, kernel_size=5, norm_type=norm_type, act_type=None,
                               mode=mode, ada_ksize=ada_ksize)
        self.model = B.sequential(fea_conv, conv1, conv2, HR_conv)

    def forward(self, x):
        x = self.model(x)
        return x


class SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                       mode=mode, res_scale=res_scale) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class ModulateSRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateSRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=1)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale,
                                             gate_conv_bias=gate_conv_bias, ada_ksize=ada_ksize,
                                             input_dim=in_nc) for _ in range(nb)]
        self.LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            self.LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            self.LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.HR_branch = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        out = self.HR_branch(fea + res)
        return out


class DenoiseResNet(nn.Module):
    """
    jingwen's addition
    denoise Resnet
    """
    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='batch', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', ada_ksize=None, down_scale=2,
                 fea_norm=None, upsample_norm=None):
        super(DenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=fea_norm, act_type=None,
                                stride=down_scale, ada_ksize=ada_ksize)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                       mode=mode, res_scale=res_scale, ada_ksize=ada_ksize)
                         for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode,
                               ada_ksize=ada_ksize)
        # LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode,
        #                        ada_ksize=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, norm_type=upsample_norm,
                                       ada_ksize=ada_ksize)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, norm_type=upsample_norm,
                                        ada_ksize=ada_ksize) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=upsample_norm, act_type=act_type,
                                ada_ksize=ada_ksize)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=upsample_norm, act_type=None,
                                ada_ksize=ada_ksize)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class ModulateDenoiseResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateDenoiseResNet, self).__init__()
        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale,
                                             gate_conv_bias=gate_conv_bias, ada_ksize=ada_ksize,
                                             input_dim=in_nc) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        upsampler = upsample_block(nf, nf, act_type=act_type)
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.LR_conv = LR_conv
        self.LR_norm = LR_norm
        self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        out = self.HR_branch(fea + res)
        return out


class NoiseSubNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, norm_type='batch', act_type='relu', mode='CNA'):
        super(NoiseSubNet, self).__init__()
        degration_block = [B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type,
                                        act_type=act_type, mode=mode)]
        degration_block.extend([B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type,
                                             act_type=act_type, mode=mode) for _ in range(15)])
        degration_block.append(B.conv_block(nf, out_nc, kernel_size=3, norm_type=None,
                                            act_type=None, mode=mode))
        self.degration_block = B.sequential(*degration_block)

    def forward(self, x):
        deg_estimate = self.degration_block(x)
        return deg_estimate


class CondDenoiseResNet(nn.Module):
    """
    jingwen's addition
    denoise Resnet
    """
    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, res_scale=1, down_scale=2, num_classes=1,
                 ada_ksize=None, upsample_mode='upconv', act_type='relu',
                 norm_type='cond_adaptive_conv_res'):
        super(CondDenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.CondResNetBlock(nf, nf, nf, num_classes=num_classes, ada_ksize=ada_ksize,
                                           norm_type=norm_type, act_type=act_type) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        if norm_type == 'cond_adaptive_conv_res':
            self.cond_adaptive = AN.CondAdaptiveConvResNorm(nf, num_classes=num_classes)
        elif norm_type == "interp_adaptive_conv_res":
            self.cond_adaptive = AN.InterpAdaptiveResNorm(nf, ada_ksize)
        elif norm_type == "cond_instance":
            self.cond_adaptive = AN.CondInstanceNorm2d(nf, num_classes=num_classes)
        elif norm_type == "cond_transform_res":
            self.cond_adaptive = AN.CondResTransformer(nf, ada_ksize, num_classes=num_classes)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.upsample = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x, y):
        # the first feature extraction
        fea = self.fea_conv(x)
        fea1, _ = self.resnet_blocks((fea, y))
        fea2 = self.LR_conv(fea1)
        fea3 = self.cond_adaptive(fea2, y)
        # res
        out = self.upsample(fea3 + fea)
        return out


class AdaptiveDenoiseResNet(nn.Module):
    """
    jingwen's addition
    adabn
    """
    def __init__(self, in_nc, nf, nb, upscale=1, res_scale=1, down_scale=2):
        super(AdaptiveDenoiseResNet, self).__init__()
        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.AdaptiveResNetBlock(nf, nf, nf, res_scale=res_scale) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(nf, affine=True, track_running_stats=True, momentum=0)

    def forward(self, x):
        fea_list = [self.fea_conv(data.unsqueeze_(0)) for data in x]
        fea_resblock_list = self.resnet_blocks(fea_list)
        fea_LR_list = [self.LR_conv(fea) for fea in fea_resblock_list]
        fea_mean, fea_var = B.computing_mean_variance(fea_LR_list)

        batch_norm_dict = self.batch_norm.state_dict()
        batch_norm_dict['running_mean'] = fea_mean
        batch_norm_dict['running_var'] = fea_var
        self.batch_norm.load_state_dict(batch_norm_dict)
        return None
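Since NoiseSubNet stacks only stride-1 3x3 convolutions (one head conv, fifteen conv-BN-ReLU blocks, one tail conv: the DnCNN layout mentioned earlier), a quick sanity check confirms the estimate keeps the input's spatial size. This snippet assumes it is run where models/modules (block.py and friends) is importable:

import torch

# shape check: the DnCNN-style stack preserves H and W
net = NoiseSubNet(in_nc=3, out_nc=3, nf=64, nb=16, norm_type='batch', mode='CNA')
x = torch.randn(2, 3, 32, 32)
print(net(x).shape)  # expected: torch.Size([2, 3, 32, 32])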
experiment