diff --git a/codes/models/SPSR_model.py b/codes/models/SPSR_model.py index f26f2d13..fe55e193 100644 --- a/codes/models/SPSR_model.py +++ b/codes/models/SPSR_model.py @@ -7,84 +7,14 @@ import torch.nn as nn from torch.optim import lr_scheduler from apex import amp -import models.SPSR_networks as networks +import models.networks as networks from .base_model import BaseModel -from models.SPSR_modules.loss import GANLoss +from models.loss import GANLoss import torchvision.utils as utils +from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding logger = logging.getLogger('base') -import torch.nn.functional as F - -class Get_gradient(nn.Module): - def __init__(self): - super(Get_gradient, self).__init__() - kernel_v = [[0, -1, 0], - [0, 0, 0], - [0, 1, 0]] - kernel_h = [[0, 0, 0], - [-1, 0, 1], - [0, 0, 0]] - kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) - kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) - self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() - self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() - - def forward(self, x): - x0 = x[:, 0] - x1 = x[:, 1] - x2 = x[:, 2] - x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2) - x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2) - - x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2) - x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2) - - x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2) - x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2) - - x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) - x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) - x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) - - x = torch.cat([x0, x1, x2], dim=1) - return x - -class Get_gradient_nopadding(nn.Module): - def __init__(self): - super(Get_gradient_nopadding, self).__init__() - kernel_v = [[0, -1, 0], - [0, 0, 0], - [0, 1, 0]] - kernel_h = [[0, 0, 0], - [-1, 0, 1], - [0, 0, 0]] - kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) - kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) - self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() - self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() - - def forward(self, x): - x0 = x[:, 0] - x1 = x[:, 1] - x2 = x[:, 2] - x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding = 1) - x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding = 1) - - x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding = 1) - x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding = 1) - - x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding = 1) - x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding = 1) - - x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) - x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) - x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) - - x = torch.cat([x0, x1, x2], dim=1) - return x - - class SPSRModel(BaseModel): def __init__(self, opt): super(SPSRModel, self).__init__(opt) @@ -93,8 +23,8 @@ class SPSRModel(BaseModel): # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: - self.netD = networks.define_D(opt).to(self.device) # D - self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad + self.netD = networks.define_D(opt).to(self.device) # D + self.netD_grad = networks.define_D(opt).to(self.device) # D_grad self.netG.train() self.netD.train() self.netD_grad.train() @@ -142,8 +72,8 @@ class SPSRModel(BaseModel): self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 # Branch_init_iters - self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0 - self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1 + self.branch_pretrain = train_opt['branch_pretrain'] if train_opt['branch_pretrain'] else 0 + self.branch_init_iters = train_opt['branch_init_iters'] if train_opt['branch_init_iters'] else 1 # gradient_pixel_loss if train_opt['gradient_pixel_weight'] > 0: @@ -217,8 +147,8 @@ class SPSRModel(BaseModel): raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() - self.get_grad = Get_gradient() - self.get_grad_nopadding = Get_gradient_nopadding() + self.get_grad = ImageGradient() + self.get_grad_nopadding = ImageGradientNoPadding() def feed_data(self, data, need_HR=True): # LR @@ -232,6 +162,12 @@ class SPSRModel(BaseModel): def optimize_parameters(self, step): + # Some generators have variants depending on the current step. + if hasattr(self.netG.module, "update_for_step"): + self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + if hasattr(self.netD.module, "update_for_step"): + self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + # G for p in self.netD.parameters(): p.requires_grad = False @@ -239,9 +175,8 @@ class SPSRModel(BaseModel): for p in self.netD_grad.parameters(): p.requires_grad = False - - if(self.Branch_pretrain): - if(step < self.Branch_init_iters): + if(self.branch_pretrain): + if(step < self.branch_init_iters): for k,v in self.netG.named_parameters(): if 'f_' not in k : v.requires_grad=False @@ -250,7 +185,6 @@ class SPSRModel(BaseModel): if 'f_' not in k : v.requires_grad=True - self.optimizer_G.zero_grad() self.fake_H_branch = [] @@ -361,43 +295,49 @@ class SPSRModel(BaseModel): os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True) # fed_LQ is not chunked. utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,))) utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,))) utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,))) + utils.save_image(self.grad_LR[0].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i.png" % (step,))) # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: # G if self.cri_pix: - self.log_dict['l_g_pix'] = l_g_pix.item() + self.add_log_entry('l_g_pix', l_g_pix.item()) if self.cri_fea: - self.log_dict['l_g_fea'] = l_g_fea.item() + self.add_log_entry('l_g_fea', l_g_fea.item()) if self.l_gan_w > 0: - self.log_dict['l_g_gan'] = l_g_gan.item() + self.add_log_entry('l_g_gan', l_g_gan.item()) if self.cri_pix_branch: #branch pixel loss - self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item() + self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item()) if self.l_gan_w > 0: - # D - self.log_dict['l_d_real'] = l_d_real.item() - self.log_dict['l_d_fake'] = l_d_fake.item() + self.add_log_entry('l_d_real', l_d_real.item()) + self.add_log_entry('l_d_fake', l_d_fake.item()) + self.add_log_entry('l_d_real_grad', l_d_real_grad.item()) + self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item()) + self.add_log_entry('D_real', torch.mean(pred_d_real.detach())) + self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + self.add_log_entry('D_real_grad', torch.mean(pred_d_real_grad.detach())) + self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach())) - # D_grad - self.log_dict['l_d_real_grad'] = l_d_real_grad.item() - self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item() - - if self.opt['train']['gan_type'] == 'wgan-gp': - self.log_dict['l_d_gp'] = l_d_gp.item() - # D outputs - self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) - self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) - - # D_grad outputs - self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach()) - self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach()) + # Allows the log to serve as an easy-to-use rotating buffer. + def add_log_entry(self, key, value): + key_it = "%s_it" % (key,) + log_rotating_buffer_size = 50 + if key not in self.log_dict.keys(): + self.log_dict[key] = [] + self.log_dict[key_it] = 0 + if len(self.log_dict[key]) < log_rotating_buffer_size: + self.log_dict[key].append(value) + else: + self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value + self.log_dict[key_it] += 1 def test(self): self.netG.eval() @@ -413,8 +353,21 @@ class SPSRModel(BaseModel): self.netG.train() + # Fetches a summary of the log. def get_current_log(self, step): - return self.log_dict + return_log = {} + for k in self.log_dict.keys(): + if not isinstance(self.log_dict[k], list): + continue + return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) + + # Some generators can do their own metric logging. + if hasattr(self.netG.module, "get_debug_values"): + return_log.update(self.netG.module.get_debug_values(step)) + if hasattr(self.netD.module, "get_debug_values"): + return_log.update(self.netD.module.get_debug_values(step)) + + return return_log def get_current_visuals(self, need_HR=True): out_dict = OrderedDict() @@ -470,6 +423,10 @@ class SPSRModel(BaseModel): if self.opt['is_train'] and load_path_D is not None: logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD) + load_path_D_grad = self.opt['path']['pretrain_model_D_grad'] + if self.opt['is_train'] and load_path_D_grad is not None: + logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad)) + self.load_network(load_path_D_grad, self.netD_grad) def compute_fea_loss(self, real, fake): if self.cri_fea is None: diff --git a/codes/models/SPSR_modules/__init__.py b/codes/models/SPSR_modules/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/SPSR_modules/architecture.py b/codes/models/SPSR_modules/architecture.py deleted file mode 100644 index 6928aa0c..00000000 --- a/codes/models/SPSR_modules/architecture.py +++ /dev/null @@ -1,366 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -from . import block as B - - -class Get_gradient_nopadding(nn.Module): - def __init__(self): - super(Get_gradient_nopadding, self).__init__() - kernel_v = [[0, -1, 0], - [0, 0, 0], - [0, 1, 0]] - kernel_h = [[0, 0, 0], - [-1, 0, 1], - [0, 0, 0]] - kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) - kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) - self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False) - - self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False) - - - def forward(self, x): - x_list = [] - for i in range(x.shape[1]): - x_i = x[:, i] - x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) - x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) - x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) - x_list.append(x_i) - - x = torch.cat(x_list, dim = 1) - - return x - - -#################### -# Generator -#################### - -class SPSRNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ - act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): - super(SPSRNet, self).__init__() - - n_upscale = int(math.log(upscale, 2)) - - if upscale == 3: - n_upscale = 1 - - fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) - rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)] - - LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) - - if upsample_mode == 'upconv': - upsample_block = B.upconv_blcok - elif upsample_mode == 'pixelshuffle': - upsample_block = B.pixelshuffle_block - else: - raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] - - self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) - self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) - - self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\ - *upsampler, self.HR_conv0_new) - - self.get_g_nopadding = Get_gradient_nopadding() - - self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) - - self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) - self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') - - - self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) - self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') - - - self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) - self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') - - - self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) - self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') - - self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) - - if upsample_mode == 'upconv': - upsample_block = B.upconv_blcok - elif upsample_mode == 'pixelshuffle': - upsample_block = B.pixelshuffle_block - else: - raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) - if upscale == 3: - b_upsampler = upsample_block(nf, nf, 3, act_type=act_type) - else: - b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] - - b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) - b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) - - self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) - - self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None) - - self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None) - - self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=norm_type, act_type=act_type, mode='CNA') - - self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) - self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) - - - def forward(self, x): - - x_grad = self.get_g_nopadding(x) - x = self.model[0](x) - - x, block_list = self.model[1](x) - - x_ori = x - for i in range(5): - x = block_list[i](x) - x_fea1 = x - - for i in range(5): - x = block_list[i+5](x) - x_fea2 = x - - for i in range(5): - x = block_list[i+10](x) - x_fea3 = x - - for i in range(5): - x = block_list[i+15](x) - x_fea4 = x - - x = block_list[20:](x) - #short cut - x = x_ori+x - x= self.model[2:](x) - x = self.HR_conv1_new(x) - - x_b_fea = self.b_fea_conv(x_grad) - x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) - - x_cat_1 = self.b_block_1(x_cat_1) - x_cat_1 = self.b_concat_1(x_cat_1) - - x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) - - x_cat_2 = self.b_block_2(x_cat_2) - x_cat_2 = self.b_concat_2(x_cat_2) - - x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) - - x_cat_3 = self.b_block_3(x_cat_3) - x_cat_3 = self.b_concat_3(x_cat_3) - - x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) - - x_cat_4 = self.b_block_4(x_cat_4) - x_cat_4 = self.b_concat_4(x_cat_4) - - x_cat_4 = self.b_LR_conv(x_cat_4) - - #short cut - x_cat_4 = x_cat_4+x_b_fea - x_branch = self.b_module(x_cat_4) - - x_out_branch = self.conv_w(x_branch) - ######## - x_branch_d = x_branch - x_f_cat = torch.cat([x_branch_d, x], dim=1) - x_f_cat = self.f_block(x_f_cat) - x_out = self.f_concat(x_f_cat) - x_out = self.f_HR_conv0(x_out) - x_out = self.f_HR_conv1(x_out) - - ######### - return x_out_branch, x_out, x_grad - - -#################### -# Discriminator -#################### - - -# VGG style Discriminator with input size 128*128 -class Discriminator_VGG_128(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_128, self).__init__() - # features - # hxw, c - # 128, 64 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 64, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 32, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 16, 256 - conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 8, 512 - conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 4, 512 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ - conv9) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -#################### -# Perceptual Network -#################### - -class VGGFeatureExtractor(nn.Module): - def __init__(self, - feature_layer=34, - use_bn=False, - use_input_norm=True, - device=torch.device('cpu')): - super(VGGFeatureExtractor, self).__init__() - if use_bn: - model = torchvision.models.vgg19_bn(pretrained=True) - else: - model = torchvision.models.vgg19(pretrained=True) - self.use_input_norm = use_input_norm - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] - self.register_buffer('mean', mean) - self.register_buffer('std', std) - self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) - # No need to BP to variable - for k, v in self.features.named_parameters(): - v.requires_grad = False - - def forward(self, x): - if self.use_input_norm: - x = (x - self.mean) / self.std - output = self.features(x) - return output - - -class ResNet101FeatureExtractor(nn.Module): - def __init__(self, use_input_norm=True, device=torch.device('cpu')): - super(ResNet101FeatureExtractor, self).__init__() - model = torchvision.models.resnet101(pretrained=True) - self.use_input_norm = use_input_norm - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] - self.register_buffer('mean', mean) - self.register_buffer('std', std) - self.features = nn.Sequential(*list(model.children())[:8]) - # No need to BP to variable - for k, v in self.features.named_parameters(): - v.requires_grad = False - - def forward(self, x): - if self.use_input_norm: - x = (x - self.mean) / self.std - output = self.features(x) - return output - - -class MINCNet(nn.Module): - def __init__(self): - super(MINCNet, self).__init__() - self.ReLU = nn.ReLU(True) - self.conv11 = nn.Conv2d(3, 64, 3, 1, 1) - self.conv12 = nn.Conv2d(64, 64, 3, 1, 1) - self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) - self.conv21 = nn.Conv2d(64, 128, 3, 1, 1) - self.conv22 = nn.Conv2d(128, 128, 3, 1, 1) - self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) - self.conv31 = nn.Conv2d(128, 256, 3, 1, 1) - self.conv32 = nn.Conv2d(256, 256, 3, 1, 1) - self.conv33 = nn.Conv2d(256, 256, 3, 1, 1) - self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) - self.conv41 = nn.Conv2d(256, 512, 3, 1, 1) - self.conv42 = nn.Conv2d(512, 512, 3, 1, 1) - self.conv43 = nn.Conv2d(512, 512, 3, 1, 1) - self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) - self.conv51 = nn.Conv2d(512, 512, 3, 1, 1) - self.conv52 = nn.Conv2d(512, 512, 3, 1, 1) - self.conv53 = nn.Conv2d(512, 512, 3, 1, 1) - - def forward(self, x): - out = self.ReLU(self.conv11(x)) - out = self.ReLU(self.conv12(out)) - out = self.maxpool1(out) - out = self.ReLU(self.conv21(out)) - out = self.ReLU(self.conv22(out)) - out = self.maxpool2(out) - out = self.ReLU(self.conv31(out)) - out = self.ReLU(self.conv32(out)) - out = self.ReLU(self.conv33(out)) - out = self.maxpool3(out) - out = self.ReLU(self.conv41(out)) - out = self.ReLU(self.conv42(out)) - out = self.ReLU(self.conv43(out)) - out = self.maxpool4(out) - out = self.ReLU(self.conv51(out)) - out = self.ReLU(self.conv52(out)) - out = self.conv53(out) - return out - - -class MINCFeatureExtractor(nn.Module): - def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \ - device=torch.device('cpu')): - super(MINCFeatureExtractor, self).__init__() - - self.features = MINCNet() - self.features.load_state_dict( - torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True) - self.features.eval() - # No need to BP to variable - for k, v in self.features.named_parameters(): - v.requires_grad = False - - def forward(self, x): - output = self.features(x) - return output diff --git a/codes/models/SPSR_modules/loss.py b/codes/models/SPSR_modules/loss.py deleted file mode 100644 index 10dacddf..00000000 --- a/codes/models/SPSR_modules/loss.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import torch.nn as nn - - -# Define GAN loss: [vanilla | lsgan] -class GANLoss(nn.Module): - def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): - super(GANLoss, self).__init__() - self.gan_type = gan_type.lower() - self.real_label_val = real_label_val - self.fake_label_val = fake_label_val - - if self.gan_type == 'vanilla': - self.loss = nn.BCEWithLogitsLoss() - elif self.gan_type == 'lsgan': - self.loss = nn.MSELoss() - else: - raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) - - def get_target_label(self, input, target_is_real): - if target_is_real: - return torch.empty_like(input).fill_(self.real_label_val) - else: - return torch.empty_like(input).fill_(self.fake_label_val) - - def forward(self, input, target_is_real): - target_label = self.get_target_label(input, target_is_real) - loss = self.loss(input, target_label) - return loss diff --git a/codes/models/SPSR_modules/sampler.py b/codes/models/SPSR_modules/sampler.py deleted file mode 100644 index 1e817667..00000000 --- a/codes/models/SPSR_modules/sampler.py +++ /dev/null @@ -1,94 +0,0 @@ -import random -import torch -import numpy as np - -def _get_random_crop_indices(crop_region, crop_size): - ''' - crop_region: (strat_y, end_y, start_x, end_x) - crop_size: (y, x) - ''' - region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2]) - if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]: - print(region_size, crop_size) - assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1] - if region_size[0] == crop_size[0]: - start_y = crop_region[0] - else: - start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0])) - if region_size[1] == crop_size[1]: - start_x = crop_region[2] - else: - start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1])) - return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1] - -def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False): - candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)] - max_choice = candidates[0] - min_choice = candidates[0] - max_dist = 0 - min_dist = np.infty - with torch.no_grad(): - for c in candidates: - start_y, end_y, start_x, end_x = c - dist = torch.sum(dist_map[start_y: end_y, start_x: end_x]) - if dist > max_dist: - max_dist = dist - max_choice = c - if dist < min_dist: - min_dist = dist - min_choice = c - if min_diff: - return min_choice - else: - return max_choice - -def get_split_list(divisor, dividend): - split_list = [dividend // divisor for _ in range(divisor - 1)] - split_list.append(dividend - (dividend // divisor) * (divisor - 1)) - return split_list - -def random_sampler(pic_size, crop_dict): - crop_region = (0, pic_size[0], 0, pic_size[1]) - crop_res_dict = {} - for k, v in crop_dict.items(): - crop_size = (int(k), int(k)) - crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)] - return crop_res_dict - -def region_sampler(crop_region, crop_dict): - crop_res_dict = {} - for k, v in crop_dict.items(): - crop_size = (int(k), int(k)) - crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)] - return crop_res_dict - -def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False): - crop_region = (0, pic_size[0], 0, pic_size[1]) - crop_res_dict = {} - for k, v in crop_dict.items(): - crop_size = (int(k), int(k)) - crop_res_dict[k] = [_get_adaptive_crop_indices(crop_region, crop_size, num_candidate_dict[k], dist_map, min_diff) for _ in range(v)] - return crop_res_dict - -# TODO more flexible -def pyramid_sampler(pic_size, crop_dict): - crop_res_dict = {} - sorted_key = list(crop_dict.keys()) - sorted_key.sort(key=lambda x: int(x), reverse=True) - k = sorted_key[0] - crop_size = (int(k), int(k)) - crop_region = (0, pic_size[0], 0, pic_size[1]) - crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(crop_dict[k])] - - for i in range(1, len(sorted_key)): - crop_res_dict[sorted_key[i]] = [] - afore_num = crop_dict[sorted_key[i-1]] - new_num = crop_dict[sorted_key[i]] - split_list = get_split_list(afore_num, new_num) - crop_size = (int(sorted_key[i]), int(sorted_key[i])) - for j in range(len(split_list)): - crop_region = crop_res_dict[sorted_key[i-1]][j] - crop_res_dict[sorted_key[i]].extend([_get_random_crop_indices(crop_region, crop_size) for _ in range(split_list[j])]) - - return crop_res_dict - diff --git a/codes/models/SPSR_networks.py b/codes/models/SPSR_networks.py deleted file mode 100644 index 575a378d..00000000 --- a/codes/models/SPSR_networks.py +++ /dev/null @@ -1,161 +0,0 @@ -import functools -import logging -import torch -import torch.nn as nn -from torch.nn import init - -import models.SPSR_modules.architecture as arch -logger = logging.getLogger('base') -#################### -# initialize -#################### - - -def weights_init_normal(m, std=0.02): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - init.normal_(m.weight.data, 0.0, std) - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('Linear') != -1: - init.normal_(m.weight.data, 0.0, std) - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('BatchNorm2d') != -1: - init.normal_(m.weight.data, 1.0, std) # BN also uses norm - init.constant_(m.bias.data, 0.0) - - -def weights_init_kaiming(m, scale=1): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('Linear') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('BatchNorm2d') != -1: - if m.affine != False: - - init.constant_(m.weight.data, 1.0) - init.constant_(m.bias.data, 0.0) - - -def weights_init_orthogonal(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - init.orthogonal_(m.weight.data, gain=1) - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('Linear') != -1: - init.orthogonal_(m.weight.data, gain=1) - if m.bias is not None: - m.bias.data.zero_() - elif classname.find('BatchNorm2d') != -1: - init.constant_(m.weight.data, 1.0) - init.constant_(m.bias.data, 0.0) - - -def init_weights(net, init_type='kaiming', scale=1, std=0.02): - # scale for 'kaiming', std for 'normal'. - if init_type == 'normal': - weights_init_normal_ = functools.partial(weights_init_normal, std=std) - net.apply(weights_init_normal_) - elif init_type == 'kaiming': - weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale) - net.apply(weights_init_kaiming_) - elif init_type == 'orthogonal': - net.apply(weights_init_orthogonal) - else: - raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type)) - - -#################### -# define network -#################### - - -# Generator -def define_G(opt, device=None): - opt_net = opt['network_G'] - which_model = opt_net['which_model_G'] - - if which_model == 'spsr_net': - netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], - nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], - act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') - else: - raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) - - if opt['is_train']: - init_weights(netG, init_type='kaiming', scale=0.1) - - return netG - - - -# Discriminator -def define_D(opt): - opt_net = opt['network_D'] - which_model = opt_net['which_model_D'] - - if which_model == 'discriminator_vgg_128': - netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - - elif which_model == 'discriminator_vgg_96': - netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - elif which_model == 'discriminator_vgg_192': - netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - elif which_model == 'discriminator_vgg_128_SN': - netD = arch.Discriminator_VGG_128_SN() - else: - raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) - - init_weights(netD, init_type='kaiming', scale=1) - - return netD - -def define_D_grad(opt): - opt_net = opt['network_D'] - which_model = opt_net['which_model_D'] - - if which_model == 'discriminator_vgg_128': - netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - - elif which_model == 'discriminator_vgg_96': - netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - elif which_model == 'discriminator_vgg_192': - netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \ - norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type']) - elif which_model == 'discriminator_vgg_128_SN': - netD = arch.Discriminator_VGG_128_SN() - else: - raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) - - - init_weights(netD, init_type='kaiming', scale=1) - - return netD - - -def define_F(opt, use_bn=False): - device = torch.device('cuda') - # pytorch pretrained VGG19-54, before ReLU. - if use_bn: - feature_layer = 49 - else: - feature_layer = 34 - netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \ - use_input_norm=True, device=device) - - netF.eval() - return netF diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 85000f40..89da1eb7 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -9,6 +9,7 @@ from models.base_model import BaseModel from models.loss import GANLoss, FDPLLoss from apex import amp from data.weight_scheduler import get_scheduler_for_opt +from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding import torch.nn.functional as F import glob import random @@ -27,11 +28,18 @@ class SRGANModel(BaseModel): else: self.rank = -1 # non dist training train_opt = opt['train'] + self.spsr_enabled = 'spsr' in opt['model'] + + # Only pixgan and gan are currently supported in spsr_mode + if self.spsr_enabled: + assert train_opt['gan_type'] == 'pixgan' or train_opt['gan_type'] == 'gan' # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if self.is_train: self.netD = networks.define_D(opt).to(self.device) + if self.spsr_enabled: + self.netD_grad = networks.define_D(opt).to(self.device) # D_grad if 'network_C' in opt.keys(): self.netC = networks.define_G(opt, net_key='network_C').to(self.device) @@ -73,6 +81,33 @@ class SRGANModel(BaseModel): else: self.fdpl_enabled = False + if self.spsr_enabled: + spsr_opt = train_opt['spsr'] + self.branch_pretrain = spsr_opt['branch_pretrain'] if spsr_opt['branch_pretrain'] else 0 + self.branch_init_iters = spsr_opt['branch_init_iters'] if spsr_opt['branch_init_iters'] else 1 + if spsr_opt['gradient_pixel_weight'] > 0: + self.cri_pix_grad = nn.MSELoss().to(self.device) + self.l_pix_grad_w = spsr_opt['gradient_pixel_weight'] + else: + self.cri_pix_grad = None + if spsr_opt['gradient_gan_weight'] > 0: + self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_grad_w = spsr_opt['gradient_gan_weight'] + else: + self.cri_grad_gan = None + if spsr_opt['pixel_branch_weight'] > 0: + l_pix_type = spsr_opt['pixel_branch_criterion'] + if l_pix_type == 'l1': + self.cri_pix_branch = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix_branch = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_branch_w = spsr_opt['pixel_branch_weight'] + else: + logger.info('Remove G_grad pixel loss.') + self.cri_pix_branch = None + # G feature loss if train_opt['feature_weight'] and train_opt['feature_weight'] > 0: # For backwards compatibility, use a scheduler definition instead. Remove this at some point. @@ -139,7 +174,7 @@ class SRGANModel(BaseModel): self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5 # optimizers - # G + # G optimizer wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR': @@ -155,6 +190,7 @@ class SRGANModel(BaseModel): weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) + # D optimizer optim_params = [] for k, v in self.netD.named_parameters(): # can optimize for a part of the model if v.requires_grad: @@ -162,16 +198,40 @@ class SRGANModel(BaseModel): else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) - # D wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) - # AMP - [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \ - amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3) + if self.spsr_enabled: + # D_grad optimizer + optim_params = [] + for k, v in self.netD_grad.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + # D + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D_grad = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], + weight_decay=wd_D, + betas=(train_opt['beta1_D'], train_opt['beta2_D'])) + self.optimizers.append(self.optimizer_D_grad) + + if self.spsr_enabled: + self.get_grad = ImageGradient().to(self.device) + self.get_grad_nopadding = ImageGradientNoPadding().to(self.device) + [self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \ + [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \ + amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], + [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad], + opt_level=self.amp_level, num_losses=3) + else: + # AMP + [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \ + amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3) # DataParallel if opt['dist']: @@ -188,6 +248,8 @@ class SRGANModel(BaseModel): self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() + if self.spsr_enabled: + self.netD_grad.train() # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': @@ -208,6 +270,10 @@ class SRGANModel(BaseModel): self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'], [0], train_opt['lr_gamma'])) + if self.spsr_enabled: + self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'], + [0], + train_opt['lr_gamma'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( @@ -284,18 +350,22 @@ class SRGANModel(BaseModel): # G for p in self.netD.parameters(): p.requires_grad = False - - if step >= self.D_init_iters: - self.optimizer_G.zero_grad() + if self.spsr_enabled: + for p in self.netD_grad.parameters(): + p.requires_grad = False self.swapout_D(step) self.swapout_G(step) # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - for p in self.netG.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool: - p.requires_grad = True + if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters: + for k, v in self.netG.named_parameters(): + v.requires_grad = '_branch_pretrain' in k + else: + for p in self.netG.parameters(): + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = True else: for p in self.netG.parameters(): p.requires_grad = False @@ -310,17 +380,32 @@ class SRGANModel(BaseModel): print("Misc setup %f" % (time() - _t,)) _t = time() + if step >= self.D_init_iters: + self.optimizer_G.zero_grad() self.fake_GenOut = [] self.fea_GenOut = [] self.fake_H = [] + self.spsr_grad_GenOut = [] var_ref_skips = [] for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): - if random.random() > self.gan_lq_img_use_prob: - fea_GenOut, fake_GenOut = self.netG(var_L) + if self.spsr_enabled: + # SPSR models have outputs from three different branches. + fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L) + fea_GenOut = fake_GenOut using_gan_img = False + # Get image gradients for later use. + fake_H_grad = self.get_grad(fake_GenOut) + var_H_grad = self.get_grad(var_H) + var_ref_grad = self.get_grad(var_ref) + var_H_grad_nopadding = self.get_grad_nopadding(var_H) + self.spsr_grad_GenOut.append(grad_LR) else: - fea_GenOut, fake_GenOut = self.netG(var_LGAN) - using_gan_img = True + if random.random() > self.gan_lq_img_use_prob: + fea_GenOut, fake_GenOut = self.netG(var_L) + using_gan_img = False + else: + fea_GenOut, fake_GenOut = self.netG(var_LGAN) + using_gan_img = True if _profile: print("Gen forward %f" % (time() - _t,)) @@ -339,6 +424,13 @@ class SRGANModel(BaseModel): l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix + if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss + l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad) + l_g_total += l_g_pix_grad + if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss + l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, + var_H_grad_nopadding) + l_g_total += l_g_pix_grad_branch if self.fdpl_enabled and not using_gan_img: l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) l_g_total += l_g_fdpl * self.fdpl_weight @@ -370,6 +462,7 @@ class SRGANModel(BaseModel): l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fix_disc + if self.l_gan_w > 0: if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']: pred_g_fake = self.netD(fake_GenOut) @@ -383,6 +476,14 @@ class SRGANModel(BaseModel): l_g_gan_log = l_g_gan / self.l_gan_w l_g_total += l_g_gan + if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss + pred_g_fake_grad = self.netD_grad(fake_H_grad) + pred_d_real_grad = self.netD_grad(var_ref_grad).detach() + + l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + + self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 + l_g_total += l_g_gan_grad + # Scale the loss down by the batch factor. l_g_total_log = l_g_total l_g_total = l_g_total / self.mega_batch_factor @@ -418,8 +519,10 @@ class SRGANModel(BaseModel): gen_input = var_LGAN # Re-compute generator outputs (post-update). with torch.no_grad(): - _, fake_H = self.netG(gen_input) - # The following line detaches all generator outputs that are not None. + if self.spsr_enabled: + _, fake_H, _ = self.netG(gen_input) + else: + _, fake_H = self.netG(gen_input) fake_H = fake_H.detach() if _profile: @@ -546,11 +649,36 @@ class SRGANModel(BaseModel): self.fake_H.append(fake_H.detach()) self.optimizer_D.step() - if _profile: print("Disc step %f" % (time() - _t,)) _t = time() + # D_grad. + if self.spsr_enabled and self.cri_grad_gan and step >= self.G_warmup: + for p in self.netD_grad.parameters(): + p.requires_grad = True + self.optimizer_D_grad.zero_grad() + + for var_ref, fake_H in zip(self.var_ref, self.fake_H): + fake_H_grad = self.get_grad(fake_H) + var_ref_grad = self.get_grad(var_ref) + pred_d_real_grad = self.netD_grad(var_ref_grad) + pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G + if self.opt['train']['gan_type'] == 'gan': + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) + l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + elif self.opt['train']['gan_type'] == 'pixgan': + real = torch.ones_like(pred_d_real_grad) + fake = torch.zeros_like(pred_d_fake_grad) + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real) + l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake) + l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 + l_d_total_grad /= self.mega_batch_factor + with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: + l_d_total_grad_scaled.backward() + self.optimizer_D_grad.step() + + # Log sample images from first microbatch. if step % self.img_debug_steps == 0: sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") @@ -562,6 +690,8 @@ class SRGANModel(BaseModel): os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) + if self.spsr_enabled: + os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True) # fed_LQ is not chunked. for i in range(self.mega_batch_factor): @@ -570,6 +700,8 @@ class SRGANModel(BaseModel): utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) + if self.spsr_enabled: + utils.save_image(self.spsr_grad_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i_%02i.png" % (step, i))) if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) @@ -594,11 +726,19 @@ class SRGANModel(BaseModel): self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) + if self.spsr_enabled: + if self.cri_pix_branch: + self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item()) if self.l_gan_w > 0 and step >= self.G_warmup: self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item()) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) + if self.spsr_enabled: + self.add_log_entry('l_d_real_grad', l_d_real_grad.item()) + self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item()) + self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach())) + self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad)) # Log learning rates. for i, pg in enumerate(self.optimizer_G.param_groups): @@ -685,7 +825,16 @@ class SRGANModel(BaseModel): def test(self): self.netG.eval() with torch.no_grad(): - self.fake_GenOut = [self.netG(self.var_L[0])] + if self.spsr_enabled: + self.fake_H_branch = [] + self.fake_GenOut = [] + self.grad_LR = [] + fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0]) + self.fake_H_branch.append(fake_H_branch) + self.fake_GenOut.append(fake_GenOut) + self.grad_LR.append(grad_LR) + else: + self.fake_GenOut = [self.netG(self.var_L[0])] self.netG.train() # Fetches a summary of the log. @@ -713,6 +862,9 @@ class SRGANModel(BaseModel): out_dict['rlt'] = gen_batch.detach().float().cpu() if need_GT: out_dict['GT'] = self.var_H[0].detach().float().cpu() + if self.spsr_enabled: + out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu() + out_dict['LR_grad'] = self.grad_LR[0].float().cpu() return out_dict def print_network(self): @@ -762,6 +914,11 @@ class SRGANModel(BaseModel): if self.opt['is_train'] and load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) + if self.spsr_enabled: + load_path_D_grad = self.opt['path']['pretrain_model_D_grad'] + if self.opt['is_train'] and load_path_D_grad is not None: + logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad)) + self.load_network(load_path_D_grad, self.netD_grad) def load_random_corruptor(self): if self.netC is None: @@ -774,3 +931,4 @@ class SRGANModel(BaseModel): def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) + self.save_network(self.netD_grad, 'D_grad', iter_step) diff --git a/codes/models/__init__.py b/codes/models/__init__.py index 1697b87a..4cb3264a 100644 --- a/codes/models/__init__.py +++ b/codes/models/__init__.py @@ -7,11 +7,11 @@ def create_model(opt): # image restoration if model == 'sr': # PSNR-oriented super resolution from .SR_model import SRModel as M - elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic + elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan': from .SRGAN_model import SRGANModel as M elif model == 'feat': from .feature_model import FeatureModel as M - if model == 'spsr': + elif model == 'spsr': from .SPSR_model import SPSRModel as M else: raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py new file mode 100644 index 00000000..90982e9f --- /dev/null +++ b/codes/models/archs/SPSR_arch.py @@ -0,0 +1,226 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.archs import SPSR_util as B +from .RRDBNet_arch import RRDB + + +class ImageGradient(nn.Module): + def __init__(self): + super(ImageGradient, self).__init__() + kernel_v = [[0, -1, 0], + [0, 0, 0], + [0, 1, 0]] + kernel_h = [[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() + self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() + + def forward(self, x): + x0 = x[:, 0] + x1 = x[:, 1] + x2 = x[:, 2] + x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2) + x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2) + + x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2) + x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2) + + x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2) + x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2) + + x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) + x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) + x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) + + x = torch.cat([x0, x1, x2], dim=1) + return x + + +class ImageGradientNoPadding(nn.Module): + def __init__(self): + super(ImageGradientNoPadding, self).__init__() + kernel_v = [[0, -1, 0], + [0, 0, 0], + [0, 1, 0]] + kernel_h = [[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False) + + self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False) + + + def forward(self, x): + x_list = [] + for i in range(x.shape[1]): + x_i = x[:, i] + x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) + x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) + x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) + x_list.append(x_i) + + x = torch.cat(x_list, dim = 1) + + return x + + +#################### +# Generator +#################### + +class SPSRNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ + act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): + super(SPSRNet, self).__init__() + + n_upscale = int(math.log(upscale, 2)) + + if upscale == 3: + n_upscale = 1 + + fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)] + + LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == 'upconv': + upsample_block = B.upconv_blcok + elif upsample_mode == 'pixelshuffle': + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + + self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) + + self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\ + *upsampler, self.HR_conv0_new) + + self.get_g_nopadding = ImageGradientNoPadding() + + self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + + self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_1 = RRDB(nf*2, gc=32) + + + self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_2 = RRDB(nf*2, gc=32) + + + self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_3 = RRDB(nf*2, gc=32) + + + self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_4 = RRDB(nf*2, gc=32) + + self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == 'upconv': + upsample_block = B.upconv_blcok + elif upsample_mode == 'pixelshuffle': + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + b_upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + + b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) + + self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) + + self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None) + + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + self._branch_pretrain_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None) + + self._branch_pretrain_block = RRDB(nf*2, gc=32) + + self._branch_pretrain_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) + + + def forward(self, x): + + x_grad = self.get_g_nopadding(x) + x = self.model[0](x) + + x, block_list = self.model[1](x) + + x_ori = x + for i in range(5): + x = block_list[i](x) + x_fea1 = x + + for i in range(5): + x = block_list[i+5](x) + x_fea2 = x + + for i in range(5): + x = block_list[i+10](x) + x_fea3 = x + + for i in range(5): + x = block_list[i+15](x) + x_fea4 = x + + x = block_list[20:](x) + #short cut + x = x_ori+x + x= self.model[2:](x) + x = self.HR_conv1_new(x) + + x_b_fea = self.b_fea_conv(x_grad) + x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) + + x_cat_1 = self.b_block_1(x_cat_1) + x_cat_1 = self.b_concat_1(x_cat_1) + + x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) + + x_cat_2 = self.b_block_2(x_cat_2) + x_cat_2 = self.b_concat_2(x_cat_2) + + x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) + + x_cat_3 = self.b_block_3(x_cat_3) + x_cat_3 = self.b_concat_3(x_cat_3) + + x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) + + x_cat_4 = self.b_block_4(x_cat_4) + x_cat_4 = self.b_concat_4(x_cat_4) + + x_cat_4 = self.b_LR_conv(x_cat_4) + + #short cut + x_cat_4 = x_cat_4+x_b_fea + x_branch = self.b_module(x_cat_4) + + x_out_branch = self.conv_w(x_branch) + ######## + x_branch_d = x_branch + x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1) + x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat) + x_out = self._branch_pretrain_concat(x__branch_pretrain_cat) + x_out = self._branch_pretrain_HR_conv0(x_out) + x_out = self._branch_pretrain_HR_conv1(x_out) + + ######### + return x_out_branch, x_out, x_grad + diff --git a/codes/models/SPSR_modules/block.py b/codes/models/archs/SPSR_util.py similarity index 60% rename from codes/models/SPSR_modules/block.py rename to codes/models/archs/SPSR_util.py index dfaa2246..0485f5ae 100644 --- a/codes/models/SPSR_modules/block.py +++ b/codes/models/archs/SPSR_util.py @@ -5,8 +5,6 @@ import torch.nn as nn #################### # Basic blocks #################### - - def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1): # helper selecting activation # neg_slope: for leakyrelu and init of prelu @@ -134,101 +132,6 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias= return sequential(n, a, p, c) -#################### -# Useful blocks -#################### - -class ResNetBlock(nn.Module): - ''' - ResNet Block, 3-3 style - with extra residual scaling used in EDSR - (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) - ''' - - def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \ - bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1): - super(ResNetBlock, self).__init__() - conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ - norm_type, act_type, mode) - if mode == 'CNA': - act_type = None - if mode == 'CNAC': # Residual path: |-CNAC-| - act_type = None - norm_type = None - conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ - norm_type, act_type, mode) - # if in_nc != out_nc: - # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ - # None, None) - # print('Need a projecter in ResNetBlock.') - # else: - # self.project = lambda x:x - self.res = sequential(conv0, conv1) - self.res_scale = res_scale - - def forward(self, x): - res = self.res(x).mul(self.res_scale) - return x + res - - -class ResidualDenseBlock_5C(nn.Module): - ''' - Residual Dense Block - style: 5 convs - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - ''' - - def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=None, act_type='leakyrelu', mode='CNA'): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ - norm_type=norm_type, act_type=act_type, mode=mode) - self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ - norm_type=norm_type, act_type=act_type, mode=mode) - self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ - norm_type=norm_type, act_type=act_type, mode=mode) - self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ - norm_type=norm_type, act_type=act_type, mode=mode) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \ - norm_type=norm_type, act_type=last_act, mode=mode) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5.mul(0.2) + x - - -class RRDB(nn.Module): - ''' - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - ''' - - def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ - norm_type=None, act_type='leakyrelu', mode='CNA'): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ - norm_type, act_type, mode) - self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ - norm_type, act_type, mode) - self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ - norm_type, act_type, mode) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out.mul(0.2) + x - - #################### # Upsampler #################### diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index dd6bd88c..42fa3eb0 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -78,6 +78,65 @@ class Discriminator_VGG_128(nn.Module): return out +class Discriminator_VGG_128_GN(nn.Module): + # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. + def __init__(self, in_nc, nf, input_img_factor=1): + super(Discriminator_VGG_128_GN, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.GroupNorm(8, nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) + final_nf = nf * 8 + + self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + #fea = torch.cat([fea, skip_med], dim=1) + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + #fea = torch.cat([fea, skip_lo], dim=1) + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.contiguous().view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out + class Discriminator_VGG_PixLoss(nn.Module): def __init__(self, in_nc, nf): super(Discriminator_VGG_PixLoss, self).__init__() diff --git a/codes/models/networks.py b/codes/models/networks.py index 82d50910..07aaa4c1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -11,6 +11,8 @@ import models.archs.feature_arch as feature_arch import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch import models.archs.SRG1_arch as srg1 import models.archs.ProgressiveSrg_arch as psrg +import models.archs.SPSR_arch as spsr +import models.archs.arch_util as arch_util import functools from collections import OrderedDict @@ -97,6 +99,12 @@ def define_G(opt, net_key='network_G'): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'], start_step=opt_net['start_step']) + elif which_model == 'spsr_net': + netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], + nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], + act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') + if opt['is_train']: + arch_util.initialize_weights(netG, scale=.1) # image corruption elif which_model == 'HighToLowResNet': @@ -119,6 +127,8 @@ def define_D_net(opt_net, img_sz=None): if which_model == 'discriminator_vgg_128': netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv']) + elif which_model == 'discriminator_vgg_128_gn': + netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_passthrough': diff --git a/codes/options/options.py b/codes/options/options.py index 726a864a..77acf126 100644 --- a/codes/options/options.py +++ b/codes/options/options.py @@ -115,7 +115,11 @@ def check_resume(opt, resume_iter): opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], '{}_G.pth'.format(resume_iter)) logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) - if 'gan' in opt['model']: + if 'gan' in opt['model'] or 'spsr' in opt['model']: opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], '{}_D.pth'.format(resume_iter)) logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) + if 'spsr' in opt['model']: + opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'], + '{}_D_grad.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad']) diff --git a/codes/train.py b/codes/train.py index 89d70ac7..1de1fda4 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -215,7 +215,7 @@ def main(): logger.info(message) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: - if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0: # image restoration validation + if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation model.force_restore_swapout() val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] # does not support multi-GPU validation