
实验记录
发布日期:2021-05-10 14:19:09
浏览次数:23
分类:精选文章
本文共 21417 字,大约阅读时间需要 71 分钟。
1、basic model 训练采用X4的数据,noise的范围为【0,15】;
2、subnetwork 的训练数据采用 noise 范围为【0,66】的数据;不采用 TV loss,同样采用 LR 进行训练;
3、finetune 两个模型(基于 BN 层),使用 X4 且 noise 范围为【0,66】的数据;
python train.py -opt options/train/train_sr.json
noise level estimation 网络的 setting
{ "name": "subnetwork_X1_0_66" // please remove "debug_" during training , "tb_logger_dir": "sft666" , "use_tb_logger": true , "model":"sr" , "scale": 1 , "crop_scale": 0 , "gpu_ids": [3,4]// , "init_type": "kaiming"//// , "finetune_type": "sft"// , "init_norm_type": "zero" , "datasets": { "train": { "name": "DIV2K" , "mode": "LRHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4" , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_66" , "subset_file": null , "use_shuffle": true , "n_workers": 8 , "batch_size": 24 // 16 , "HR_size": 96 // 128 | 192 | 96 , "noise_gt": true//residual for the noise , "use_flip": true , "use_rot": true } , "val": { "name": "val_set5_c03s08_LR_mod4" , "mode": "LRHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4" , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_66" , "noise_gt": true//residual for the noise } } , "path": { "root": "/home/guanwp/Blind_Restoration-master/sft666" , "pretrain_model_G": null }// , "network_G": { "which_model_G": "denoise_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet // , "norm_type": "adaptive_conv_res" //, "norm_type": "batch" , "mode": "CNA", "nf": 64, "nb": 16, "in_nc": 3, "out_nc": 3 // , "gc": 32 , "group": 1 ,"down_scale": 2//denoise srresnet // , "gate_conv_bias": true // , "ada_ksize": 1 // , "num_classes": 2 }// , "network_G": {// "which_model_G": "srcnn" // RRDB_net | sr_resnet , "norm_type": null// , "norm_type": "adaptive_conv_res"// , "mode": "CNA"// , "nf": 64// , "in_nc": 1// , "out_nc": 1// , "ada_ksize": 5// } , "train": {// "lr_G": 1e-3 "lr_G": 1e-4 , "lr_scheme": "MultiStepLR"// , "lr_steps": [200000, 400000, 600000, 800000] , "lr_steps": [500000] , "lr_gamma": 0.1// , "lr_gamma": 0.5 , "pixel_criterion": "l2" , "pixel_criterion_reg": "tv" , "pixel_weight": 1.0 , "val_freq": 2e3 , "manual_seed": 0 , "niter": 1e6// , "niter": 6e5 } , "logger": { "print_freq": 200 , 
"save_checkpoint_freq": 2e3 }}
basic model的setting
{ "name": "subnetwork_X1_0_66" // please remove "debug_" during training , "tb_logger_dir": "sft666" , "use_tb_logger": true , "model":"sr" , "scale": 4 , "crop_scale": 0 , "gpu_ids": [3,4]// , "init_type": "kaiming"//// , "finetune_type": "sft"// , "init_norm_type": "zero" , "datasets": { "train": { "name": "DIV2K" , "mode": "LRHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_15" , "subset_file": null , "use_shuffle": true , "n_workers": 8 , "batch_size": 24 // 16 , "HR_size": 96 // 128 | 192 | 96 //, "noise_gt": true//residual for the noise , "use_flip": true , "use_rot": true } , "val": { "name": "val_set5_c03s08_LR_mod4" , "mode": "LRHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5" , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_15" //, "noise_gt": true//residual for the noise } } , "path": { "root": "/home/guanwp/Blind_Restoration-master/sft666" , "pretrain_model_G": null }// , "network_G": { "which_model_G": "sr_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet // , "norm_type": "adaptive_conv_res" , "norm_type": null//"batch" , "mode": "CNA", "nf": 64, "nb": 16, "in_nc": 3, "out_nc": 3 // , "gc": 32 , "group": 1 //,"down_scale": 2//denoise srresnet // , "gate_conv_bias": true // , "ada_ksize": 1 // , "num_classes": 2 }// , "network_G": {// "which_model_G": "srcnn" // RRDB_net | sr_resnet , "norm_type": null// , "norm_type": "adaptive_conv_res"// , "mode": "CNA"// , "nf": 64// , "in_nc": 1// , "out_nc": 1// , "ada_ksize": 5// } , "train": {// "lr_G": 1e-3 "lr_G": 1e-4 , "lr_scheme": "MultiStepLR"// , "lr_steps": [200000, 400000, 600000, 800000] , "lr_steps": [500000] , "lr_gamma": 0.1// , "lr_gamma": 0.5 , "pixel_criterion": "l2" , "pixel_criterion_reg": "tv" , "pixel_weight": 1.0 , "val_freq": 2e3 , "manual_seed": 0 , "niter": 1e6// , "niter": 6e5 } , "logger": { "print_freq": 200 , 
"save_checkpoint_freq": 2e3 }}
模型训练好的结果如下:
好,接下来训练 SFT-layer。原来的 basic model 中既没有 BN 层,也没有 SFT-layer,即网络结构发生了变化,因此需要先用 transfer_params_degration.py 对模型参数进行转换:
"""Convert a plain SRResNet checkpoint ("basic model") into the key layout of
the SFT-modulated generator.

The basic model was trained without BN / SFT layers, so its ``state_dict``
keys (``model.<idx>...``) do not match the modulated network's named
sub-modules (``fea_conv``, ``norm_branch``, ``LR_conv``, ``HR_branch``).
This script copies every convolution weight/bias across under the new key
names. Parameters that have no counterpart in the basic model are simply not
present in the output; how they are handled on load is decided by the loader.

Usage::

    python transfer_params_degration.py [--src IN.pth] [--dst OUT.pth]
        [--num-blocks 16] [--degradation SUBNET.pth]
"""
import argparse
import os
from collections import OrderedDict

import torch
from torch.nn import init  # noqa: F401  (kept from the original script; unused)

# Paths that were active in the experiment described by this log.
DEFAULT_SRC = ('/home/jwhe/workspace/BasicSR_v3/sr/experiments/'
               'LR_srx4_resnet_denoise_DIV2K/models/516000_G.pth')
DEFAULT_DST = ('../../experiments/pretrained_models/sr/'
               'LR_srx4_resnet_denoise_DIV2K/basicmodel_516000.pth')

# (source index inside a residual block's ``res``, target ``conv_blockN``).
_RES_CONV_MAP = ((0, 0), (2, 1))

# (source index in ``model``, target index in ``HR_branch``) for the
# upsampling tail used here. Skipped indices are presumably parameter-free
# layers (pixel shuffle / activation) -- confirm against the network
# definition. An earlier variant of this script used ((2, 0), (5, 3), (7, 5));
# pass a different map for other upsampling tails.
HR_LAYER_MAP = ((2, 0), (5, 3), (8, 6), (10, 8))


def convert_basic_model(pretrained_net, num_blocks=16, hr_layer_map=HR_LAYER_MAP):
    """Return *pretrained_net*'s tensors re-keyed for the modulated network.

    Args:
        pretrained_net: mapping loaded from a basic-model checkpoint
            (``model.*`` keys).
        num_blocks: number of residual blocks (``nb`` in the training config).
            The LR conv is looked up at ``model.1.sub.<num_blocks>``, i.e.
            directly after the last residual block.
        hr_layer_map: iterable of (source ``model`` index,
            target ``HR_branch`` index) pairs.

    Returns:
        OrderedDict with keys in the same order the original script produced.

    Raises:
        KeyError: if an expected source key is missing from *pretrained_net*.
    """
    converted = OrderedDict()
    params = ('weight', 'bias')

    # first feature conv
    for p in params:
        converted['fea_conv.0.' + p] = pretrained_net['model.0.' + p]

    # residual blocks
    for i in range(num_blocks):
        for src_idx, dst_idx in _RES_CONV_MAP:
            for p in params:
                converted['norm_branch.{:d}.conv_block{:d}.0.{}'.format(i, dst_idx, p)] = \
                    pretrained_net['model.1.sub.{:d}.res.{:d}.{}'.format(i, src_idx, p)]

    # LR conv closing the residual trunk
    for p in params:
        converted['LR_conv.0.' + p] = pretrained_net['model.1.sub.{:d}.{}'.format(num_blocks, p)]

    # HR (upsampling) branch
    for src_idx, dst_idx in hr_layer_map:
        for p in params:
            converted['HR_branch.{:d}.{}'.format(dst_idx, p)] = \
                pretrained_net['model.{:d}.{}'.format(src_idx, p)]

    return converted


def main(argv=None):
    """Load the basic-model checkpoint, convert it and save the result."""
    parser = argparse.ArgumentParser(
        description='Convert a basic SRResNet checkpoint to the SFT-modulated layout.')
    parser.add_argument('--src', default=DEFAULT_SRC, help='basic-model checkpoint (.pth)')
    parser.add_argument('--dst', default=DEFAULT_DST, help='output checkpoint (.pth)')
    parser.add_argument('--num-blocks', type=int, default=16, help='number of residual blocks')
    parser.add_argument('--degradation', default=None,
                        help='optional degradation/sub-network checkpoint whose '
                             'entries are merged into the output unchanged')
    args = parser.parse_args(argv)

    # map_location='cpu' lets the conversion run on a machine without the
    # GPU the checkpoint was saved from.
    pretrained_net = torch.load(args.src, map_location='cpu')
    adaptive_norm_net = convert_basic_model(pretrained_net, num_blocks=args.num_blocks)

    if args.degradation is not None:
        adaptive_norm_net.update(torch.load(args.degradation, map_location='cpu'))

    print('OK. \n Saving model...')
    # torch.save does not create missing directories
    out_dir = os.path.dirname(args.dst)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(adaptive_norm_net, args.dst)


if __name__ == '__main__':
    main()
python train.py -opt options/train/train_sub_sr.json
setting(注意不要使用 noise_gt,即噪声的 ground truth):
{ "name": "finetune_noiseestimation_sftlayer_15to66" // please remove "debug_" during training , "tb_logger_dir": "sft666" , "use_tb_logger": true , "model":"sr_sub" , "scale": 4 , "crop_scale": 4 , "gpu_ids": [4,5]// , "init_type": "kaiming"//// , "finetune_type": "sft_basic" //sft | basic// , "init_norm_type": "zero" , "datasets": { "train": { "name": "DIV2K" , "mode": "LRHR" , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_66" , "subset_file": null , "use_shuffle": true , "n_workers": 8 , "batch_size": 24 // 16 , "HR_size": 96 // 128 | 192 | 96 //, "noise_gt": true , "use_flip": true , "use_rot": true }// , "val": { "name": "val_set5_x4_c03s08_mod4", "mode": "LRHR", "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5", "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_66" // , "noise_gt": true }//// , "val": {// "name": "val_set5_x3_gray_mod6"// , "mode": "LRHR"// , "dataroot_HR": "/media/sdc/jwhe/BasicSR_v2/data/val/Set5_val/mod6/Set5_gray_mod6"// , "dataroot_LR": "/media/sdc/jwhe/BasicSR_v2/data/val/Set5_val/mod6/Set5_gray_bicx3"// } } , "path": { "root": "/home/guanwp/Blind_Restoration-master/sft666" , "pretrain_model_G": "/home/guanwp/Blind_Restoration-master/sft/686000_G.pth" , "pretrain_model_sub": "/home/guanwp/Blind_Restoration-master/sft666/experiments/subnetwork_X1_0_66/models/528000_G.pth" } , "network_G": { "which_model_G": "modulate_sr_resnet" // RRDB_net | sr_resnet | modulate_sr_resnet// , "norm_type": "adaptive_conv_res" , "norm_type": "sft"// , "norm_type": null , "mode": "CNA" , "nf": 64 , "nb": 16 , "in_nc": 3 , "out_nc": 3// , "gc": 32 , "group": 1 , "gate_conv_bias": true } , "network_sub": { "which_model_sub": "denoise_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet // , "norm_type": "adaptive_conv_res" , //"norm_type": "batch", "mode": "CNA", "nf": 64 , "nb": 16 , "in_nc": 3, "out_nc": 3, 
"group": 1, "down_scale": 2 } , "train": {// "lr_G": 1e-3 "lr_G": 1e-4 , "lr_scheme": "MultiStepLR"// , "lr_steps": [200000, 400000, 600000, 800000] , "lr_steps": [500000] , "lr_gamma": 0.1// , "lr_gamma": 0.5 , "pixel_criterion_basic": "l2" , "pixel_criterion_noise": "l2" , "pixel_criterion_reg_noise": "tv" , "pixel_weight_basic": 1.0 , "pixel_weight_noise": 1.0 , "val_freq": 2e3 , "manual_seed": 0 , "niter": 1e6// , "niter": 6e5 } , "logger": { "print_freq": 200 , "save_checkpoint_freq": 2e3 }}
首先给出 sub_srmodel 的代码。关键部分在 optimize_parameters 中,从中可以看到数据的流向。与级联网络不同的是,这里不需要把 LR 图像和噪声图做 concat,而是以 (LR, noise) 元组的形式输入网络。
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel
from .modules.loss import TVLoss


class SRModel(BaseModel):
    """Noise-estimation sub-network + SFT-modulated SR generator.

    ``subnet`` predicts a noise map from the LR input; ``netG`` receives the
    ``(LR, noise_map)`` pair as a tuple and produces the SR output. Which
    parameters are optimised is chosen by ``opt['finetune_type']`` (see
    ``__define_grad_params``).

    ``self.device``, ``self.is_train``, ``self.save_dir``, ``self.optimizers``,
    ``self.schedulers`` and the load/save/description helpers are used here
    but defined in ``BaseModel``.
    """

    def __init__(self, opt):
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        finetune_type = opt['finetune_type']

        # define networks and load pretrained weights
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet = networks.define_sub(opt).to(self.device)
        self.load()

        # filled by feed_data(); None means "not provided for this batch"
        self.real_H = None
        self.mid_L = None
        self.log_dict = OrderedDict()

        if self.is_train:
            self.netG.train()
            self.subnet.train()

            # supervised loss on the estimated noise map (only used when a
            # noise ground truth is fed, see optimize_parameters)
            self.cri_pix_noise = self._pixel_loss(train_opt['pixel_criterion_noise'])
            self.l_pix_noise_w = train_opt['pixel_weight_noise']

            # optional regulariser on the noise map; needs no ground truth
            if train_opt['pixel_criterion_reg_noise'] == 'tv':
                self.cri_pix_reg_noise = TVLoss(0.00001).to(self.device)
            else:
                self.cri_pix_reg_noise = None

            # SR reconstruction loss
            self.cri_pix_basic = self._pixel_loss(train_opt['pixel_criterion_basic'])
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            self.optim_params = self.__define_grad_params(finetune_type)
            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def _pixel_loss(self, loss_type):
        """Return an L1 ('l1') or MSE ('l2') criterion placed on ``self.device``."""
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

    def feed_data(self, data, need_HR=True):
        """Move the current batch to ``self.device``.

        ``data['LR']`` is required. ``data['HR']`` is stored when *need_HR* is
        true and it is present. ``data['MR']`` (noise ground truth) is stored
        when present; otherwise ``self.mid_L`` is ``None`` and the supervised
        noise loss is skipped.
        """
        self.var_L = data['LR'].to(self.device)  # LR
        self.real_H = data['HR'].to(self.device) if need_HR and 'HR' in data else None
        self.mid_L = data['MR'].to(self.device) if 'MR' in data else None

    def __define_grad_params(self, finetune_type=None):
        """Set ``requires_grad`` on both networks and return the params to optimise.

        finetune_type:
            'sft'               -- only netG params whose name contains 'Gate'.
            'sub_sft'           -- as 'sft', plus subnet params whose name
                                   contains 'degration'.
            'basic'/'sft_basic' -- all netG params; subnet frozen.
            anything else       -- every param that already requires grad.
        """
        optim_params = []

        def only_matching(net, keyword):
            # freeze net except the parameters whose name contains keyword
            for k, v in net.named_parameters():
                v.requires_grad = keyword in k
                if v.requires_grad:
                    optim_params.append(v)
                    print('we only optimize params: {}'.format(k))

        def all_trainable(net):
            # keep existing requires_grad flags and collect the trainable ones
            for k, v in net.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                    print('params [{:s}] will optimize.'.format(k))
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))

        if finetune_type == 'sft':
            only_matching(self.netG, 'Gate')
        elif finetune_type == 'sub_sft':
            only_matching(self.netG, 'Gate')
            only_matching(self.subnet, 'degration')
        elif finetune_type in ('basic', 'sft_basic'):
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))
            for v in self.subnet.parameters():
                v.requires_grad = False
        else:
            all_trainable(self.netG)
            all_trainable(self.subnet)
        return optim_params

    def optimize_parameters(self, step):
        """Run one optimisation step on the batch given to feed_data().

        The SR reconstruction loss is always used and requires ``self.real_H``.
        The supervised noise loss is added only when a noise ground truth
        (``self.mid_L``) was fed; the noise regulariser is added whenever it is
        configured.
        """
        self.optimizer_G.zero_grad()

        self.fake_noise = self.subnet(self.var_L)
        l_pix_noise = 0
        if self.mid_L is not None:
            l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)
        if self.cri_pix_reg_noise is not None:
            l_pix_noise = l_pix_noise + self.cri_pix_reg_noise(self.fake_noise)

        # netG takes the (LR, noise map) pair as a tuple, not a channel concat
        self.fake_H = self.netG((self.var_L, self.fake_noise))
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)

        l_pix = l_pix_noise + l_pix_basic
        l_pix.backward()
        self.optimizer_G.step()

        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Run inference on ``self.var_L`` without building an autograd graph.

        Both networks are put in eval mode for the forward pass and switched
        back to train mode afterwards.
        """
        self.netG.eval()
        self.subnet.eval()
        with torch.no_grad():
            self.fake_noise = self.subnet(self.var_L)
            # must match optimize_parameters: netG expects a (LR, noise) tuple
            self.fake_H = self.netG((self.var_L, self.fake_noise))
        self.netG.train()
        self.subnet.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of LR / estimated noise (MR) / SR (/ HR) on CPU."""
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR and self.real_H is not None:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Print parameter counts; when training, also write both nets to network.txt."""
        s_g, n_g = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n_g))
        s_sub, n_sub = self.get_network_description(self.subnet)
        print('Number of parameters in subnet: {:,d}'.format(n_sub))
        if self.is_train:
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write('-------------- Generator --------------\n' + s_g + '\n')
                f.write('\n\n\n-------------- subnet --------------\n' + s_sub + '\n')

    def load(self):
        """Load pretrained weights for netG / subnet when paths are configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        load_path_sub = self.opt['path']['pretrain_model_sub']
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub is not None:
            print('loading model for subnet [{:s}] ...'.format(load_path_sub))
            self.load_network(load_path_sub, self.subnet)

    def save(self, iter_label):
        """Save both networks under the labels 'G' and 'sub'."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet, 'sub', iter_label)
关于SFT-layer部分可以参考博文《》
通过 SFT-layer,使 basic model 适用的 noise 范围从【0,15】扩展到【0,66】。
结果:
发表评论
最新留言
做的很好,不错不错
[***.243.131.199]2025年04月20日 06时23分45秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
高项论文——论信息系统的文档编制 解答要点
2021-05-11