diff --git a/codes/models/SPSR_model.py b/codes/models/SPSR_model.py new file mode 100644 index 00000000..9c89b556 --- /dev/null +++ b/codes/models/SPSR_model.py @@ -0,0 +1,462 @@ +import os +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +import models.SPSR_networks as networks +from .base_model import BaseModel +from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss +logger = logging.getLogger('base') + +import torch.nn.functional as F + +class Get_gradient(nn.Module): + def __init__(self): + super(Get_gradient, self).__init__() + kernel_v = [[0, -1, 0], + [0, 0, 0], + [0, 1, 0]] + kernel_h = [[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() + self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() + + def forward(self, x): + x0 = x[:, 0] + x1 = x[:, 1] + x2 = x[:, 2] + x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2) + x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2) + + x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2) + x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2) + + x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2) + x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2) + + x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) + x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) + x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) + + x = torch.cat([x0, x1, x2], dim=1) + return x + +class Get_gradient_nopadding(nn.Module): + def __init__(self): + super(Get_gradient_nopadding, self).__init__() + kernel_v = [[0, -1, 0], + [0, 0, 0], + [0, 1, 0]] + kernel_h = [[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() + self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() + + def forward(self, x): + x0 = x[:, 0] + x1 = x[:, 1] + x2 = x[:, 2] + x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding = 1) + x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding = 1) + + x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding = 1) + x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding = 1) + + x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding = 1) + x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding = 1) + + x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) + x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) + x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) + + x = torch.cat([x0, x1, x2], dim=1) + return x + + +class SPSRModel(BaseModel): + def __init__(self, opt): + super(SPSRModel, self).__init__(opt) + train_opt = opt['train'] + + # define networks and load pretrained models + self.netG = networks.define_G(opt).to(self.device) # G + if self.is_train: + self.netD = networks.define_D(opt).to(self.device) # D + self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad + self.netG.train() + self.netD.train() + self.netD_grad.train() + self.load() # load G and D if needed + + # define losses, optimizer and scheduler + if self.is_train: + # G pixel loss + if train_opt['pixel_weight'] > 0: + l_pix_type = 
train_opt['pixel_criterion'] + if l_pix_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_w = train_opt['pixel_weight'] + else: + logger.info('Remove pixel loss.') + self.cri_pix = None + + # G feature loss + if train_opt['feature_weight'] > 0: + l_fea_type = train_opt['feature_criterion'] + if l_fea_type == 'l1': + self.cri_fea = nn.L1Loss().to(self.device) + elif l_fea_type == 'l2': + self.cri_fea = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) + self.l_fea_w = train_opt['feature_weight'] + else: + logger.info('Remove feature loss.') + self.cri_fea = None + if self.cri_fea: # load VGG perceptual loss + self.netF = networks.define_F(opt, use_bn=False).to(self.device) + + # GD gan loss + self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_w = train_opt['gan_weight'] + # D_update_ratio and D_init_iters are for WGAN + self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 + self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + # Branch_init_iters + self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0 + self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1 + + if train_opt['gan_type'] == 'wgan-gp': + self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) + # gradient penalty loss + self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) + self.l_gp_w = train_opt['gp_weigth'] + + # gradient_pixel_loss + if train_opt['gradient_pixel_weight'] > 0: + self.cri_pix_grad = nn.MSELoss().to(self.device) + self.l_pix_grad_w = train_opt['gradient_pixel_weight'] + else: + self.cri_pix_grad = None + + # gradient_gan_loss + if train_opt['gradient_gan_weight'] > 0: + self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_grad_w = train_opt['gradient_gan_weight'] + else: + self.cri_grad_gan = None + + # G_grad pixel loss + if train_opt['pixel_branch_weight'] > 0: + l_pix_type = train_opt['pixel_branch_criterion'] + if l_pix_type == 'l1': + self.cri_pix_branch = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix_branch = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_branch_w = train_opt['pixel_branch_weight'] + else: + logger.info('Remove G_grad pixel loss.') + self.cri_pix_branch = None + + # optimizers + # G + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + + optim_params = [] + for k, v in self.netG.named_parameters(): # optimize part of the model + + if v.requires_grad: + optim_params.append(v) + else: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ + weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) + self.optimizers.append(self.optimizer_G) + + # D + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ + weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) + + self.optimizers.append(self.optimizer_D) + + # D_grad + wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 
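+            # D_grad reuses D's optimizer settings; note the Adam call below passes
+            # wd_D, which is identical in value to wd_D_grad computed above.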
+ self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'], \ + weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) + + self.optimizers.append(self.optimizer_D_grad) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ + train_opt['lr_steps'], train_opt['lr_gamma'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + self.get_grad = Get_gradient() + self.get_grad_nopadding = Get_gradient_nopadding() + + def feed_data(self, data, need_HR=True): + # LR + self.var_L = data['LQ'].to(self.device) + + if need_HR: # train or val + self.var_H = data['GT'].to(self.device) + input_ref = data['ref'] if 'ref' in data else data['GT'] + self.var_ref = input_ref.to(self.device) + + + + def optimize_parameters(self, step): + # G + for p in self.netD.parameters(): + p.requires_grad = False + + for p in self.netD_grad.parameters(): + p.requires_grad = False + + + if(self.Branch_pretrain): + if(step < self.Branch_init_iters): + for k,v in self.netG.named_parameters(): + if 'f_' not in k : + v.requires_grad=False + else: + for k,v in self.netG.named_parameters(): + if 'f_' not in k : + v.requires_grad=True + + + self.optimizer_G.zero_grad() + + self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L) + + + self.fake_H_grad = self.get_grad(self.fake_H) + self.var_H_grad = self.get_grad(self.var_H) + self.var_ref_grad = self.get_grad(self.var_ref) + self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H) + + + l_g_total = 0 + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: # pixel loss + l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) + l_g_total += l_g_pix + if self.cri_fea: # feature loss + real_fea = self.netF(self.var_H).detach() + fake_fea = self.netF(self.fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_total += l_g_fea + + if self.cri_pix_grad: #gradient pixel loss + l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad) + l_g_total += l_g_pix_grad + + + if self.cri_pix_branch: #branch pixel loss + l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding) + l_g_total += l_g_pix_grad_branch + + + # G gan + cls loss + pred_g_fake = self.netD(self.fake_H) + pred_d_real = self.netD(self.var_ref).detach() + + l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_total += l_g_gan + + # grad G gan + cls loss + + pred_g_fake_grad = self.netD_grad(self.fake_H_grad) + pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach() + + l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + + self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 + l_g_total += l_g_gan_grad + + + l_g_total.backward() + self.optimizer_G.step() + + + # D + for p in self.netD.parameters(): + p.requires_grad = True + + self.optimizer_D.zero_grad() + l_d_total = 0 + pred_d_real = self.netD(self.var_ref) + pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + + l_d_total = (l_d_real + l_d_fake) / 2 + + if 
self.opt['train']['gan_type'] == 'wgan-gp':
+            batch_size = self.var_ref.size(0)
+            if self.random_pt.size(0) != batch_size:
+                self.random_pt.resize_(batch_size, 1, 1, 1)
+            self.random_pt.uniform_()  # Draw random interpolation points
+            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
+            interp.requires_grad = True
+            interp_crit = self.netD(interp)  # netD returns a single logit tensor
+            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
+            l_d_total += l_d_gp
+
+        l_d_total.backward()
+
+        self.optimizer_D.step()
+
+
+        for p in self.netD_grad.parameters():
+            p.requires_grad = True
+
+        self.optimizer_D_grad.zero_grad()
+        l_d_total_grad = 0
+
+
+        pred_d_real_grad = self.netD_grad(self.var_ref_grad)
+        pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())  # detach to avoid BP to G
+
+        l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
+        l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
+
+        l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
+
+
+        l_d_total_grad.backward()
+
+        self.optimizer_D_grad.step()
+
+        # Periodically save sample images for inspection.
+        if step % 50 == 0:
+            import torchvision.utils as utils
+            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
+            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
+            # save_image lays the whole batch out as a grid.
+            utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
+            utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
+            utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
+
+
+        # set log
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            # G
+            if self.cri_pix:
+                self.log_dict['l_g_pix'] = l_g_pix.item()
+            if self.cri_fea:
+                self.log_dict['l_g_fea'] = l_g_fea.item()
+            self.log_dict['l_g_gan'] = l_g_gan.item()
+
+            if self.cri_pix_branch:  # branch pixel loss
+                self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
+
+        # D
+        self.log_dict['l_d_real'] = l_d_real.item()
+        self.log_dict['l_d_fake'] = l_d_fake.item()
+
+        # D_grad
+        self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
+        self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            self.log_dict['l_d_gp'] = l_d_gp.item()
+        # D outputs
+        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+        # D_grad outputs
+        self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
+        self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())
+
+    def test(self):
+        self.netG.eval()
+        with torch.no_grad():
+            self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
+
+        self.netG.train()
+
+    def get_current_log(self, step):
+        return self.log_dict
+
+    def get_current_visuals(self, need_HR=True):
+        out_dict = OrderedDict()
+        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
+        out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
+        if need_HR:
+            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        # Generator
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+        logger.info(s)
+        if self.is_train:
+            # Discriminator
+            s, n = self.get_network_description(self.netD)
+            if isinstance(self.netD, nn.DataParallel):
+                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                 self.netD.module.__class__.__name__)
+            else:
+                net_struc_str = '{}'.format(self.netD.__class__.__name__)
+
+            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+            logger.info(s)
+
+            if self.cri_fea:  # F, Perceptual Network
+                s, n = self.get_network_description(self.netF)
+                if isinstance(self.netF, nn.DataParallel):
+                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                     self.netF.module.__class__.__name__)
+                else:
+                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
+
+                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                logger.info(s)
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.netG)
+        load_path_D = self.opt['path']['pretrain_model_D']
+        if self.opt['is_train'] and load_path_D is not None:
+            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+            self.load_network(load_path_D, self.netD)
+
+    def save(self, iter_step):
+        self.save_network(self.netG, 'G', iter_step)
+        self.save_network(self.netD, 'D', iter_step)
+        self.save_network(self.netD_grad, 'D_grad', iter_step)
diff --git a/codes/models/SPSR_modules/__init__.py b/codes/models/SPSR_modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/SPSR_modules/architecture.py b/codes/models/SPSR_modules/architecture.py
new file mode 100644
index 00000000..12d63b24
--- /dev/null
+++ b/codes/models/SPSR_modules/architecture.py
@@ -0,0 +1,654 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from . import block as B
+from . 
import spectral_norm as SN + + + +class Get_gradient_nopadding(nn.Module): + def __init__(self): + super(Get_gradient_nopadding, self).__init__() + kernel_v = [[0, -1, 0], + [0, 0, 0], + [0, 1, 0]] + kernel_h = [[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False) + + self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False) + + + def forward(self, x): + x_list = [] + for i in range(x.shape[1]): + x_i = x[:, i] + x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) + x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) + x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) + x_list.append(x_i) + + x = torch.cat(x_list, dim = 1) + + return x + + +#################### +# Generator +#################### + +class SPSRNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ + act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): + super(SPSRNet, self).__init__() + + n_upscale = int(math.log(upscale, 2)) + + if upscale == 3: + n_upscale = 1 + + fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)] + + LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == 'upconv': + upsample_block = B.upconv_blcok + elif upsample_mode == 'pixelshuffle': + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + + self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) + + self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\ + *upsampler, self.HR_conv0_new) + + self.get_g_nopadding = Get_gradient_nopadding() + + self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + + self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') + + + self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') + + + self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') + + + self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None) + self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') + + self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == 'upconv': + 
upsample_block = B.upconv_blcok + elif upsample_mode == 'pixelshuffle': + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + b_upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + + b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None) + + self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) + + self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None) + + self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None) + + self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA') + + self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) + + + def forward(self, x): + + x_grad = self.get_g_nopadding(x) + x = self.model[0](x) + + x, block_list = self.model[1](x) + + x_ori = x + for i in range(5): + x = block_list[i](x) + x_fea1 = x + + for i in range(5): + x = block_list[i+5](x) + x_fea2 = x + + for i in range(5): + x = block_list[i+10](x) + x_fea3 = x + + for i in range(5): + x = block_list[i+15](x) + x_fea4 = x + + x = block_list[20:](x) + #short cut + x = x_ori+x + x= self.model[2:](x) + x = self.HR_conv1_new(x) + + x_b_fea = self.b_fea_conv(x_grad) + x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) + + x_cat_1 = self.b_block_1(x_cat_1) + x_cat_1 = self.b_concat_1(x_cat_1) + + x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) + + x_cat_2 = self.b_block_2(x_cat_2) + x_cat_2 = self.b_concat_2(x_cat_2) + + x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) + + x_cat_3 = self.b_block_3(x_cat_3) + x_cat_3 = self.b_concat_3(x_cat_3) + + x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) + + x_cat_4 = self.b_block_4(x_cat_4) + x_cat_4 = self.b_concat_4(x_cat_4) + + x_cat_4 = self.b_LR_conv(x_cat_4) + + #short cut + x_cat_4 = x_cat_4+x_b_fea + x_branch = self.b_module(x_cat_4) + + x_out_branch = self.conv_w(x_branch) + ######## + x_branch_d = x_branch + x_f_cat = torch.cat([x_branch_d, x], dim=1) + x_f_cat = self.f_block(x_f_cat) + x_out = self.f_concat(x_f_cat) + x_out = self.f_HR_conv0(x_out) + x_out = self.f_HR_conv1(x_out) + + ######### + return x_out_branch, x_out, x_grad + + +#################### +# Discriminator +#################### + + +# VGG style Discriminator with input size 128*128 +class Discriminator_VGG_128(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_128, self).__init__() + # features + # hxw, c + # 128, 64 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 64, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 32, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = 
B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 16, 256 + conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 8, 512 + conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 4, 512 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ + conv9) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +# VGG style Discriminator with input size 96*96 +class Discriminator_VGG_96(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_96, self).__init__() + # features + # hxw, c + # 96, 3 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 48, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 24, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 12, 256 + conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 6, 512 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +# VGG style Discriminator with input size 64*64 +class Discriminator_VGG_64(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_64, self).__init__() + # features + # hxw, c + # 64, 3 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 32, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 16, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, 
stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 8, 256 + conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 4, 512 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +# VGG style Discriminator with input size 32*32 +class Discriminator_VGG_32(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_32, self).__init__() + # features + # hxw, c + # 32, 3 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 16, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 8, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 4, 256 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +# VGG style Discriminator with input size 16*16 +class Discriminator_VGG_16(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_16, self).__init__() + # features + # hxw, c + # 16, 3 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 8, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 4, 128 + self.features = B.sequential(conv0, conv1, conv2, conv3) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +# VGG style Discriminator with input size 128*128, Spectral Normalization +class Discriminator_VGG_128_SN(nn.Module): + def __init__(self): + super(Discriminator_VGG_128_SN, self).__init__() + # features + # hxw, c + # 128, 64 + self.lrelu = nn.LeakyReLU(0.2, True) + + self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1)) + self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1)) + # 64, 64 + self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1)) + self.conv3 = 
SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1)) + # 32, 128 + self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1)) + self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1)) + # 16, 256 + self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1)) + self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1)) + # 8, 512 + self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1)) + self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1)) + # 4, 512 + + # classifier + self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100)) + self.linear1 = SN.spectral_norm(nn.Linear(100, 1)) + + def forward(self, x): + x = self.lrelu(self.conv0(x)) + x = self.lrelu(self.conv1(x)) + x = self.lrelu(self.conv2(x)) + x = self.lrelu(self.conv3(x)) + x = self.lrelu(self.conv4(x)) + x = self.lrelu(self.conv5(x)) + x = self.lrelu(self.conv6(x)) + x = self.lrelu(self.conv7(x)) + x = self.lrelu(self.conv8(x)) + x = self.lrelu(self.conv9(x)) + x = x.view(x.size(0), -1) + x = self.lrelu(self.linear0(x)) + x = self.linear1(x) + return x + + +class Discriminator_VGG_96(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_96, self).__init__() + # features + # hxw, c + # 96, 64 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 48, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 24, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 12, 256 + conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 6, 512 + conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 3, 512 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ + conv9) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +class Discriminator_VGG_192(nn.Module): + def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): + super(Discriminator_VGG_192, self).__init__() + # features + # hxw, c + # 192, 64 + conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ + mode=mode) + conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 96, 64 + conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ + 
act_type=act_type, mode=mode) + # 48, 128 + conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 24, 256 + conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 12, 512 + conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 6, 512 + conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ + act_type=act_type, mode=mode) + conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ + act_type=act_type, mode=mode) + # 3, 512 + self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ + conv9, conv10, conv11) + + # classifier + self.classifier = nn.Sequential( + nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +#################### +# Perceptual Network +#################### + +class VGGFeatureExtractor(nn.Module): + def __init__(self, + feature_layer=34, + use_bn=False, + use_input_norm=True, + device=torch.device('cpu')): + super(VGGFeatureExtractor, self).__init__() + if use_bn: + model = torchvision.models.vgg19_bn(pretrained=True) + else: + model = torchvision.models.vgg19(pretrained=True) + self.use_input_norm = use_input_norm + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = False + + def forward(self, x): + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output + + +class ResNet101FeatureExtractor(nn.Module): + def __init__(self, use_input_norm=True, device=torch.device('cpu')): + super(ResNet101FeatureExtractor, self).__init__() + model = torchvision.models.resnet101(pretrained=True) + self.use_input_norm = use_input_norm + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.children())[:8]) + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = False + + def forward(self, x): + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output + + +class MINCNet(nn.Module): + def __init__(self): + super(MINCNet, 
self).__init__() + self.ReLU = nn.ReLU(True) + self.conv11 = nn.Conv2d(3, 64, 3, 1, 1) + self.conv12 = nn.Conv2d(64, 64, 3, 1, 1) + self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) + self.conv21 = nn.Conv2d(64, 128, 3, 1, 1) + self.conv22 = nn.Conv2d(128, 128, 3, 1, 1) + self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) + self.conv31 = nn.Conv2d(128, 256, 3, 1, 1) + self.conv32 = nn.Conv2d(256, 256, 3, 1, 1) + self.conv33 = nn.Conv2d(256, 256, 3, 1, 1) + self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) + self.conv41 = nn.Conv2d(256, 512, 3, 1, 1) + self.conv42 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv43 = nn.Conv2d(512, 512, 3, 1, 1) + self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True) + self.conv51 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv52 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv53 = nn.Conv2d(512, 512, 3, 1, 1) + + def forward(self, x): + out = self.ReLU(self.conv11(x)) + out = self.ReLU(self.conv12(out)) + out = self.maxpool1(out) + out = self.ReLU(self.conv21(out)) + out = self.ReLU(self.conv22(out)) + out = self.maxpool2(out) + out = self.ReLU(self.conv31(out)) + out = self.ReLU(self.conv32(out)) + out = self.ReLU(self.conv33(out)) + out = self.maxpool3(out) + out = self.ReLU(self.conv41(out)) + out = self.ReLU(self.conv42(out)) + out = self.ReLU(self.conv43(out)) + out = self.maxpool4(out) + out = self.ReLU(self.conv51(out)) + out = self.ReLU(self.conv52(out)) + out = self.conv53(out) + return out + + +class MINCFeatureExtractor(nn.Module): + def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \ + device=torch.device('cpu')): + super(MINCFeatureExtractor, self).__init__() + + self.features = MINCNet() + self.features.load_state_dict( + torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True) + self.features.eval() + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = False + + def forward(self, x): + output = self.features(x) + return output diff --git a/codes/models/SPSR_modules/block.py b/codes/models/SPSR_modules/block.py new file mode 100644 index 00000000..07123bca --- /dev/null +++ b/codes/models/SPSR_modules/block.py @@ -0,0 +1,258 @@ +from collections import OrderedDict +import torch +import torch.nn as nn + +#################### +# Basic blocks +#################### + + +def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1): + # helper selecting activation + # neg_slope: for leakyrelu and init of prelu + # n_prelu: for p_relu num_parameters + act_type = act_type.lower() + if act_type == 'relu': + layer = nn.ReLU(inplace) + elif act_type == 'leakyrelu': + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == 'prelu': + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + else: + raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) + return layer + +def norm(norm_type, nc): + # helper selecting normalization layer + norm_type = norm_type.lower() + if norm_type == 'batch': + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == 'instance': + layer = nn.InstanceNorm2d(nc, affine=False) + else: + raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) + return layer + +def pad(pad_type, padding): + # helper selecting padding layer + # if padding is 'zero', do by conv layers + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == 'reflect': + layer = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + 
layer = nn.ReplicationPad2d(padding)
+    else:
+        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+    return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+    padding = (kernel_size - 1) // 2
+    return padding
+
+
+class ConcatBlock(nn.Module):
+    # Concat the output of a submodule to its input
+    def __init__(self, submodule):
+        super(ConcatBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x):
+        output = torch.cat((x, self.sub(x)), dim=1)
+        return output
+
+    def __repr__(self):
+        tmpstr = 'Identity .. \n|'
+        modstr = self.sub.__repr__().replace('\n', '\n|')
+        tmpstr = tmpstr + modstr
+        return tmpstr
+
+
+class ShortcutBlock(nn.Module):
+    # SPSR variant of the usual elementwise-sum shortcut: forward() returns the
+    # input together with the wrapped submodule, so the caller (SPSRNet.forward)
+    # can run the inner blocks one by one, tap intermediate features, and apply
+    # the residual sum itself.
+    def __init__(self, submodule):
+        super(ShortcutBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x):
+        return x, self.sub
+
+    def __repr__(self):
+        tmpstr = 'Identity + \n|'
+        modstr = self.sub.__repr__().replace('\n', '\n|')
+        tmpstr = tmpstr + modstr
+        return tmpstr
+
+
+def sequential(*args):
+    # Flatten Sequential. It unwraps nn.Sequential.
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError('sequential does not support OrderedDict input.')
+        return args[0]  # No sequential is needed.
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
+               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
+    '''
+    Conv layer with padding, normalization, activation
+    mode: CNA --> Conv -> Norm -> Act
+          NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+    '''
+    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+    padding = get_valid_padding(kernel_size, dilation)
+    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+    padding = padding if pad_type == 'zero' else 0
+
+    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
+                  dilation=dilation, bias=bias, groups=groups)
+    a = act(act_type) if act_type else None
+    if 'CNA' in mode:
+        n = norm(norm_type, out_nc) if norm_type else None
+        return sequential(p, c, n, a)
+    elif mode == 'NAC':
+        if norm_type is None and act_type is not None:
+            a = act(act_type, inplace=False)
+            # Important! 
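+            # With no norm layer in front, the activation reads the block's input
+            # tensor directly, so it must not run inplace: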
+ # input----ReLU(inplace)----Conv--+----output + # |________________________| + # inplace ReLU will modify the input, therefore wrong output + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + + +#################### +# Useful blocks +#################### + +class ResNetBlock(nn.Module): + ''' + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + ''' + + def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \ + bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1): + super(ResNetBlock, self).__init__() + conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ + norm_type, act_type, mode) + if mode == 'CNA': + act_type = None + if mode == 'CNAC': # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ + norm_type, act_type, mode) + # if in_nc != out_nc: + # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ + # None, None) + # print('Need a projecter in ResNetBlock.') + # else: + # self.project = lambda x:x + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x): + res = self.res(x).mul(self.res_scale) + return x + res + + +class ResidualDenseBlock_5C(nn.Module): + ''' + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + ''' + + def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=None, act_type='leakyrelu', mode='CNA'): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode) + self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode) + self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode) + self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode) + if mode == 'CNA': + last_act = None + else: + last_act = act_type + self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=last_act, mode=mode) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5.mul(0.2) + x + + +class RRDB(nn.Module): + ''' + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + ''' + + def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=None, act_type='leakyrelu', mode='CNA'): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode) + self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode) + self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out.mul(0.2) + x + + +#################### +# Upsampler +#################### + + +def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ + pad_type='zero', norm_type=None, act_type='relu'): + ''' + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + ''' + conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \ + pad_type=pad_type, norm_type=None, act_type=None) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ + pad_type='zero', norm_type=None, act_type='relu', mode='nearest'): + # Up conv + # described in https://distill.pub/2016/deconv-checkerboard/ + upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \ + pad_type=pad_type, norm_type=norm_type, act_type=act_type) + return sequential(upsample, conv) diff --git a/codes/models/SPSR_modules/loss.py b/codes/models/SPSR_modules/loss.py new file mode 100644 index 00000000..25c2d765 --- /dev/null +++ b/codes/models/SPSR_modules/loss.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if 
self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \ + grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0] + grad_interp = grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1)**2).mean() + return loss diff --git a/codes/models/SPSR_modules/sampler.py b/codes/models/SPSR_modules/sampler.py new file mode 100644 index 00000000..1e817667 --- /dev/null +++ b/codes/models/SPSR_modules/sampler.py @@ -0,0 +1,94 @@ +import random +import torch +import numpy as np + +def _get_random_crop_indices(crop_region, crop_size): + ''' + crop_region: (strat_y, end_y, start_x, end_x) + crop_size: (y, x) + ''' + region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2]) + if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]: + print(region_size, crop_size) + assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1] + if region_size[0] == crop_size[0]: + start_y = crop_region[0] + else: + start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0])) + if region_size[1] == crop_size[1]: + start_x = crop_region[2] + else: + start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1])) + return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1] + +def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False): + candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)] + max_choice = candidates[0] + min_choice = candidates[0] + max_dist = 0 + min_dist = np.infty + with torch.no_grad(): + for c in candidates: + start_y, end_y, start_x, end_x = c + dist = torch.sum(dist_map[start_y: end_y, start_x: end_x]) + if dist > max_dist: + max_dist = dist + max_choice = c + if dist < min_dist: + min_dist = dist + min_choice = c + if min_diff: + return min_choice + else: + return max_choice + +def get_split_list(divisor, dividend): + split_list = [dividend // divisor for _ in range(divisor - 1)] + split_list.append(dividend - (dividend // divisor) * (divisor - 1)) + return split_list + +def random_sampler(pic_size, crop_dict): + crop_region = (0, 
pic_size[0], 0, pic_size[1]) + crop_res_dict = {} + for k, v in crop_dict.items(): + crop_size = (int(k), int(k)) + crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)] + return crop_res_dict + +def region_sampler(crop_region, crop_dict): + crop_res_dict = {} + for k, v in crop_dict.items(): + crop_size = (int(k), int(k)) + crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)] + return crop_res_dict + +def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False): + crop_region = (0, pic_size[0], 0, pic_size[1]) + crop_res_dict = {} + for k, v in crop_dict.items(): + crop_size = (int(k), int(k)) + crop_res_dict[k] = [_get_adaptive_crop_indices(crop_region, crop_size, num_candidate_dict[k], dist_map, min_diff) for _ in range(v)] + return crop_res_dict + +# TODO more flexible +def pyramid_sampler(pic_size, crop_dict): + crop_res_dict = {} + sorted_key = list(crop_dict.keys()) + sorted_key.sort(key=lambda x: int(x), reverse=True) + k = sorted_key[0] + crop_size = (int(k), int(k)) + crop_region = (0, pic_size[0], 0, pic_size[1]) + crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(crop_dict[k])] + + for i in range(1, len(sorted_key)): + crop_res_dict[sorted_key[i]] = [] + afore_num = crop_dict[sorted_key[i-1]] + new_num = crop_dict[sorted_key[i]] + split_list = get_split_list(afore_num, new_num) + crop_size = (int(sorted_key[i]), int(sorted_key[i])) + for j in range(len(split_list)): + crop_region = crop_res_dict[sorted_key[i-1]][j] + crop_res_dict[sorted_key[i]].extend([_get_random_crop_indices(crop_region, crop_size) for _ in range(split_list[j])]) + + return crop_res_dict + diff --git a/codes/models/SPSR_modules/spectral_norm.py b/codes/models/SPSR_modules/spectral_norm.py new file mode 100644 index 00000000..3ca2b636 --- /dev/null +++ b/codes/models/SPSR_modules/spectral_norm.py @@ -0,0 +1,149 @@ +''' +Copy from pytorch github repo +Spectral Normalization from https://arxiv.org/abs/1802.05957 +''' +import torch +from torch.nn.functional import normalize +from torch.nn.parameter import Parameter + + +class SpectralNorm(object): + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def compute_weight(self, module): + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + weight_mat = weight_mat.reshape(height, -1) + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. 
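+                # Each iteration: v <- normalize(W^T u), u <- normalize(W v).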
+ v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) + u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) + + sigma = torch.dot(u, torch.matmul(weight_mat, v)) + weight = weight / sigma + return weight, u + + def remove(self, module): + weight = getattr(module, self.name) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, torch.nn.Parameter(weight)) + + def __call__(self, module, inputs): + if module.training: + weight, u = self.compute_weight(module) + setattr(module, self.name, weight) + setattr(module, self.name + '_u', u) + else: + r_g = getattr(module, self.name + '_orig').requires_grad + getattr(module, self.name).detach_().requires_grad_(r_g) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + height = weight.size(dim) + + u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a + # buffer, which will cause weight to be included in the state dict + # and also supports nn.init due to shared storage. + module.register_buffer(fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + + module.register_forward_pre_hook(fn) + return fn + + +def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ + \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generaive Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. 
+def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
+    r"""Applies spectral normalization to a parameter in the given module.
+
+    .. math::
+        \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
+        \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+    Spectral normalization stabilizes the training of discriminators (critics)
+    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+    with spectral norm :math:`\sigma` of the weight matrix calculated using
+    power iteration method. If the dimension of the weight tensor is greater
+    than 2, it is reshaped to 2D in power iteration method to get spectral
+    norm. This is implemented via a hook that calculates spectral norm and
+    rescales weight before every :meth:`~Module.forward` call.
+
+    See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+    Args:
+        module (nn.Module): containing module
+        name (str, optional): name of weight parameter
+        n_power_iterations (int, optional): number of power iterations to
+            calculate spectral norm
+        eps (float, optional): epsilon for numerical stability in
+            calculating norms
+        dim (int, optional): dimension corresponding to number of outputs,
+            the default is 0, except for modules that are instances of
+            ConvTranspose1/2/3d, when it is 1
+
+    Returns:
+        The original module with the spectral norm hook
+
+    Example::
+
+        >>> m = spectral_norm(nn.Linear(20, 40))
+        >>> m.weight_u.size()
+        torch.Size([40])
+
+    """
+    if dim is None:
+        if isinstance(
+                module,
+                (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
+            dim = 1
+        else:
+            dim = 0
+    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+    return module
+
+
+def remove_spectral_norm(module, name='weight'):
+    r"""Removes the spectral normalization reparameterization from a module.
+
+    Args:
+        module (nn.Module): containing module
+        name (str, optional): name of weight parameter
+
+    Example:
+        >>> m = spectral_norm(nn.Linear(40, 10))
+        >>> remove_spectral_norm(m)
+    """
+    for k, hook in module._forward_pre_hooks.items():
+        if isinstance(hook, SpectralNorm) and hook.name == name:
+            hook.remove(module)
+            del module._forward_pre_hooks[k]
+            return module
+
+    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
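A usage sketch for the two entry points in this file, assuming it is importable as models.SPSR_modules.spectral_norm:

    import torch.nn as nn
    from models.SPSR_modules.spectral_norm import spectral_norm, remove_spectral_norm

    conv = spectral_norm(nn.Conv2d(3, 64, 3, padding=1))
    print(hasattr(conv, 'weight_orig'))  # True: original weight kept as a Parameter
    print(conv.weight_u.size())          # torch.Size([64]): one u entry per output channel
    conv = remove_spectral_norm(conv)    # restores a plain `weight` Parameter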
diff --git a/codes/models/SPSR_networks.py b/codes/models/SPSR_networks.py
new file mode 100644
index 00000000..575a378d
--- /dev/null
+++ b/codes/models/SPSR_networks.py
@@ -0,0 +1,161 @@
+import functools
+import logging
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import models.SPSR_modules.architecture as arch
+logger = logging.getLogger('base')
+####################
+# initialize
+####################
+
+
+def weights_init_normal(m, std=0.02):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.normal_(m.weight.data, 0.0, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.normal_(m.weight.data, 0.0, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 1.0, std)  # BN also uses normal init
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m, scale=1):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+        m.weight.data *= scale
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+        m.weight.data *= scale
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        if m.affine:
+            init.constant_(m.weight.data, 1.0)
+            init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        init.constant_(m.weight.data, 1.0)
+        init.constant_(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='kaiming', scale=1, std=0.02):
+    # scale for 'kaiming', std for 'normal'.
+    if init_type == 'normal':
+        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
+        net.apply(weights_init_normal_)
+    elif init_type == 'kaiming':
+        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
+        net.apply(weights_init_kaiming_)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
+
+
+####################
+# define network
+####################
+
+
+# Generator
+def define_G(opt, device=None):
+    opt_net = opt['network_G']
+    which_model = opt_net['which_model_G']
+
+    if which_model == 'spsr_net':
+        netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
+            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
+            act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
+    else:
+        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
+
+    if opt['is_train']:
+        init_weights(netG, init_type='kaiming', scale=0.1)
+
+    return netG
+
+
+# Discriminator
+def define_D(opt):
+    opt_net = opt['network_D']
+    which_model = opt_net['which_model_D']
+
+    if which_model == 'discriminator_vgg_128':
+        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_96':
+        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_192':
+        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_128_SN':
+        netD = arch.Discriminator_VGG_128_SN()
+    else:
+        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
+
+    init_weights(netD, init_type='kaiming', scale=1)
+    return netD
+
+
+# Discriminator for the gradient branch (same architecture choices as define_D)
+def define_D_grad(opt):
+    opt_net = opt['network_D']
+    which_model = opt_net['which_model_D']
+
+    if which_model == 'discriminator_vgg_128':
+        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_96':
+        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_192':
+        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_128_SN':
+        netD = arch.Discriminator_VGG_128_SN()
+    else:
+        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
+
+    init_weights(netD, init_type='kaiming', scale=1)
+    return netD
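For reference, a hypothetical minimal opt dict that would satisfy this define_G; the real values come from the training YAML, and the nf/nb/gc numbers below are only typical ESRGAN-style settings, not confirmed SPSR defaults:

    import models.SPSR_networks as networks

    opt = {
        'is_train': True,
        'network_G': {
            'which_model_G': 'spsr_net',
            'in_nc': 3, 'out_nc': 3,       # RGB in / RGB out
            'nf': 64, 'nb': 23, 'gc': 32,  # feature width, block count, growth channels (guesses)
            'scale': 4,                    # 4x super-resolution
            'norm_type': None,
            'mode': 'CNA',                 # Conv-Norm-Act block ordering
        },
    }
    netG = networks.define_G(opt)  # builds an SPSRNet, then kaiming init with scale=0.1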
+def define_F(opt, use_bn=False):
+    device = torch.device('cuda')
+    # PyTorch pretrained VGG19; take features before the ReLU of conv5_4
+    # (layer 34 without BN, layer 49 with BN).
+    if use_bn:
+        feature_layer = 49
+    else:
+        feature_layer = 34
+    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \
+        use_input_norm=True, device=device)
+    netF.eval()
+    return netF
diff --git a/codes/models/__init__.py b/codes/models/__init__.py
index 3dae848c..1697b87a 100644
--- a/codes/models/__init__.py
+++ b/codes/models/__init__.py
@@ -11,6 +11,8 @@ def create_model(opt):
         from .SRGAN_model import SRGANModel as M
     elif model == 'feat':
         from .feature_model import FeatureModel as M
+    elif model == 'spsr':
+        from .SPSR_model import SPSRModel as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)
diff --git a/codes/train.py b/codes/train.py
index e8e5ec53..4162cb38 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
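The net effect of this train.py change: launching the script without -opt now falls back to the SPSR config. The same argparse behavior in isolation:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../options/train_imgset_spsr.yml')
    args = parser.parse_args([])  # empty argv, so the default applies
    print(args.opt)               # ../options/train_imgset_spsr.yml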