diff --git a/codes/models/SPSR_model.py b/codes/models/SPSR_model.py index 9c89b556..f26f2d13 100644 --- a/codes/models/SPSR_model.py +++ b/codes/models/SPSR_model.py @@ -5,10 +5,13 @@ from collections import OrderedDict import torch import torch.nn as nn from torch.optim import lr_scheduler +from apex import amp import models.SPSR_networks as networks from .base_model import BaseModel -from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss +from models.SPSR_modules.loss import GANLoss +import torchvision.utils as utils + logger = logging.getLogger('base') import torch.nn.functional as F @@ -95,10 +98,13 @@ class SPSRModel(BaseModel): self.netG.train() self.netD.train() self.netD_grad.train() + self.mega_batch_factor = 1 self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: + self.mega_batch_factor = train_opt['mega_batch_factor'] + # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] @@ -139,12 +145,6 @@ class SPSRModel(BaseModel): self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0 self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1 - if train_opt['gan_type'] == 'wgan-gp': - self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) - # gradient penalty loss - self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) - self.l_gp_w = train_opt['gp_weigth'] - # gradient_pixel_loss if train_opt['gradient_pixel_weight'] > 0: self.cri_pix_grad = nn.MSELoss().to(self.device) @@ -202,6 +202,12 @@ class SPSRModel(BaseModel): self.optimizers.append(self.optimizer_D_grad) + # AMP + [self.netG, self.netD, self.netD_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \ + amp.initialize([self.netG, self.netD, self.netD_grad], + [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad], + opt_level=self.amp_level, num_losses=3) + # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: @@ -216,12 +222,12 @@ class SPSRModel(BaseModel): def feed_data(self, data, need_HR=True): # LR - self.var_L = data['LQ'].to(self.device) + self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)] if need_HR: # train or val - self.var_H = data['GT'].to(self.device) + self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] input_ref = data['ref'] if 'ref' in data else data['GT'] - self.var_ref = input_ref.to(self.device) + self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref.to(self.device), chunks=self.mega_batch_factor, dim=0)] @@ -247,118 +253,118 @@ class SPSRModel(BaseModel): self.optimizer_G.zero_grad() - self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L) + self.fake_H_branch = [] + self.fake_H = [] + self.grad_LR = [] + for var_L, var_H, var_ref in zip(self.var_L, self.var_H, self.var_ref): + fake_H_branch, fake_H, grad_LR = self.netG(var_L) + self.fake_H_branch.append(fake_H_branch.detach()) + self.fake_H.append(fake_H.detach()) + self.grad_LR.append(grad_LR.detach()) - - self.fake_H_grad = self.get_grad(self.fake_H) - self.var_H_grad = self.get_grad(self.var_H) - self.var_ref_grad = self.get_grad(self.var_ref) - self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H) - + fake_H_grad = self.get_grad(fake_H) + var_H_grad = self.get_grad(var_H) + var_ref_grad = self.get_grad(var_ref) + var_H_grad_nopadding = self.get_grad_nopadding(var_H) + + l_g_total = 0 + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: # pixel loss + l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H) + l_g_total += l_g_pix + if self.cri_fea: # feature loss + real_fea = self.netF(var_H).detach() + fake_fea = self.netF(fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_total += l_g_fea + + if self.cri_pix_grad: #gradient pixel loss + l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad) + l_g_total += l_g_pix_grad + + if self.cri_pix_branch: #branch pixel loss + l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding) + l_g_total += l_g_pix_grad_branch + + if self.l_gan_w > 0: + # G gan + cls loss + pred_g_fake = self.netD(fake_H) + pred_d_real = self.netD(var_ref).detach() + + l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_total += l_g_gan + + if self.cri_grad_gan: + # grad G gan + cls loss + pred_g_fake_grad = self.netD_grad(fake_H_grad) + pred_d_real_grad = self.netD_grad(var_ref_grad).detach() + + l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + + self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 + l_g_total += l_g_gan_grad + + l_g_total /= self.mega_batch_factor + with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: + l_g_total_scaled.backward() - l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: - if self.cri_pix: # pixel loss - l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) - l_g_total += l_g_pix - if self.cri_fea: # feature loss - real_fea = self.netF(self.var_H).detach() - fake_fea = self.netF(self.fake_H) - l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) - l_g_total += l_g_fea - - if self.cri_pix_grad: #gradient pixel loss - l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad) - l_g_total += l_g_pix_grad - - - if self.cri_pix_branch: #branch pixel loss - l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding) - l_g_total += l_g_pix_grad_branch - - - # G gan + cls loss - pred_g_fake = self.netD(self.fake_H) - pred_d_real = self.netD(self.var_ref).detach() - - l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + - self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 - l_g_total += l_g_gan - - # grad G gan + cls loss - - pred_g_fake_grad = self.netD_grad(self.fake_H_grad) - pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach() - - l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + - self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 - l_g_total += l_g_gan_grad - - - l_g_total.backward() self.optimizer_G.step() - # D - for p in self.netD.parameters(): - p.requires_grad = True + if self.l_gan_w > 0: + # D + for p in self.netD.parameters(): + p.requires_grad = True - self.optimizer_D.zero_grad() - l_d_total = 0 - pred_d_real = self.netD(self.var_ref) - pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G - l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) - l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + self.optimizer_D.zero_grad() + for var_ref, fake_H in zip(self.var_ref, self.fake_H): + pred_d_real = self.netD(var_ref) + pred_d_fake = self.netD(fake_H) # detach to avoid BP to G + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) - l_d_total = (l_d_real + l_d_fake) / 2 + l_d_total = (l_d_real + l_d_fake) / 2 - if self.opt['train']['gan_type'] == 'wgan-gp': - batch_size = self.var_ref.size(0) - if self.random_pt.size(0) != batch_size: - self.random_pt.resize_(batch_size, 1, 1, 1) - self.random_pt.uniform_() # Draw random interpolation points - interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref - interp.requires_grad = True - interp_crit, _ = self.netD(interp) - l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit) - l_d_total += l_d_gp + l_d_total /= self.mega_batch_factor + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() - l_d_total.backward() + self.optimizer_D.step() - self.optimizer_D.step() + if self.cri_grad_gan: + for p in self.netD_grad.parameters(): + p.requires_grad = True - - for p in self.netD_grad.parameters(): - p.requires_grad = True + self.optimizer_D_grad.zero_grad() + for var_ref, fake_H in zip(self.var_ref, self.fake_H): + fake_H_grad = self.get_grad(fake_H) + var_ref_grad = self.get_grad(var_ref) - self.optimizer_D_grad.zero_grad() - l_d_total_grad = 0 + pred_d_real_grad = self.netD_grad(var_ref_grad) + pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G - - pred_d_real_grad = self.netD_grad(self.var_ref_grad) - pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach()) # detach to avoid BP to G - - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) - l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) + l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) - l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 + l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 + l_d_total_grad /= self.mega_batch_factor + with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: + l_d_total_grad_scaled.backward() - l_d_total_grad.backward() - - self.optimizer_D_grad.step() + self.optimizer_D_grad.step() # Log sample images from first microbatch. if step % 50 == 0: - import torchvision.utils as utils sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) # fed_LQ is not chunked. - utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,))) - utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,))) - utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,))) + utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,))) + utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,))) + utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,))) # set log @@ -368,33 +374,42 @@ class SPSRModel(BaseModel): self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() - self.log_dict['l_g_gan'] = l_g_gan.item() + if self.l_gan_w > 0: + self.log_dict['l_g_gan'] = l_g_gan.item() if self.cri_pix_branch: #branch pixel loss self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item() - - # D - self.log_dict['l_d_real'] = l_d_real.item() - self.log_dict['l_d_fake'] = l_d_fake.item() - # D_grad - self.log_dict['l_d_real_grad'] = l_d_real_grad.item() - self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item() + if self.l_gan_w > 0: + # D + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() - if self.opt['train']['gan_type'] == 'wgan-gp': - self.log_dict['l_d_gp'] = l_d_gp.item() - # D outputs - self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) - self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + # D_grad + self.log_dict['l_d_real_grad'] = l_d_real_grad.item() + self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item() - # D_grad outputs - self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach()) - self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach()) + if self.opt['train']['gan_type'] == 'wgan-gp': + self.log_dict['l_d_gp'] = l_d_gp.item() + # D outputs + self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + + # D_grad outputs + self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach()) + self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach()) def test(self): self.netG.eval() with torch.no_grad(): - self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L) + self.fake_H_branch = [] + self.fake_H = [] + self.grad_LR = [] + for var_L in self.var_L: + fake_H_branch, fake_H, grad_LR = self.netG(var_L) + self.fake_H_branch.append(fake_H_branch) + self.fake_H.append(fake_H) + self.grad_LR.append(grad_LR) self.netG.train() @@ -403,13 +418,13 @@ class SPSRModel(BaseModel): def get_current_visuals(self, need_HR=True): out_dict = OrderedDict() - out_dict['LR'] = self.var_L.detach()[0].float().cpu() + out_dict['LR'] = self.var_L[0].float().cpu() - out_dict['SR'] = self.fake_H.detach()[0].float().cpu() - out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu() - out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu() + out_dict['rlt'] = self.fake_H[0].float().cpu() + out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu() + out_dict['LR_grad'] = self.grad_LR[0].float().cpu() if need_HR: - out_dict['HR'] = self.var_H.detach()[0].float().cpu() + out_dict['GT'] = self.var_H[0].float().cpu() return out_dict def print_network(self): @@ -456,7 +471,31 @@ class SPSRModel(BaseModel): logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD) + def compute_fea_loss(self, real, fake): + if self.cri_fea is None: + return 0 + with torch.no_grad(): + real = real.unsqueeze(dim=0).to(self.device) + fake = fake.unsqueeze(dim=0).to(self.device) + real_fea = self.netF(real).detach() + fake_fea = self.netF(fake) + return self.cri_fea(fake_fea, real_fea).item() + + def force_restore_swapout(self): + pass + def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) self.save_network(self.netD_grad, 'D_grad', iter_step) + + # override of load_network that allows loading partial params (like RRDB_PSNR_x4) + def load_network(self, load_path, network, strict=True): + if isinstance(network, nn.DataParallel): + network = network.module + pretrained_dict = torch.load(load_path) + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + + model_dict.update(pretrained_dict) + network.load_state_dict(model_dict) \ No newline at end of file diff --git a/codes/models/SPSR_modules/architecture.py b/codes/models/SPSR_modules/architecture.py index 12d63b24..6928aa0c 100644 --- a/codes/models/SPSR_modules/architecture.py +++ b/codes/models/SPSR_modules/architecture.py @@ -4,8 +4,6 @@ import torch.nn as nn import torch.nn.functional as F import torchvision from . import block as B -from . import spectral_norm as SN - class Get_gradient_nopadding(nn.Module): @@ -248,292 +246,6 @@ class Discriminator_VGG_128(nn.Module): return x -# VGG style Discriminator with input size 96*96 -class Discriminator_VGG_96(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_96, self).__init__() - # features - # hxw, c - # 96, 3 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 48, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 24, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 12, 256 - conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 6, 512 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -# VGG style Discriminator with input size 64*64 -class Discriminator_VGG_64(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_64, self).__init__() - # features - # hxw, c - # 64, 3 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 32, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 16, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 8, 256 - conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 4, 512 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -# VGG style Discriminator with input size 32*32 -class Discriminator_VGG_32(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_32, self).__init__() - # features - # hxw, c - # 32, 3 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 16, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 8, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 4, 256 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -# VGG style Discriminator with input size 16*16 -class Discriminator_VGG_16(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_16, self).__init__() - # features - # hxw, c - # 16, 3 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 8, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 4, 128 - self.features = B.sequential(conv0, conv1, conv2, conv3) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -# VGG style Discriminator with input size 128*128, Spectral Normalization -class Discriminator_VGG_128_SN(nn.Module): - def __init__(self): - super(Discriminator_VGG_128_SN, self).__init__() - # features - # hxw, c - # 128, 64 - self.lrelu = nn.LeakyReLU(0.2, True) - - self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1)) - self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1)) - # 64, 64 - self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1)) - self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1)) - # 32, 128 - self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1)) - self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1)) - # 16, 256 - self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1)) - self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1)) - # 8, 512 - self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1)) - self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1)) - # 4, 512 - - # classifier - self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100)) - self.linear1 = SN.spectral_norm(nn.Linear(100, 1)) - - def forward(self, x): - x = self.lrelu(self.conv0(x)) - x = self.lrelu(self.conv1(x)) - x = self.lrelu(self.conv2(x)) - x = self.lrelu(self.conv3(x)) - x = self.lrelu(self.conv4(x)) - x = self.lrelu(self.conv5(x)) - x = self.lrelu(self.conv6(x)) - x = self.lrelu(self.conv7(x)) - x = self.lrelu(self.conv8(x)) - x = self.lrelu(self.conv9(x)) - x = x.view(x.size(0), -1) - x = self.lrelu(self.linear0(x)) - x = self.linear1(x) - return x - - -class Discriminator_VGG_96(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_96, self).__init__() - # features - # hxw, c - # 96, 64 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 48, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 24, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 12, 256 - conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 6, 512 - conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 3, 512 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ - conv9) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -class Discriminator_VGG_192(nn.Module): - def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'): - super(Discriminator_VGG_192, self).__init__() - # features - # hxw, c - # 192, 64 - conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ - mode=mode) - conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 96, 64 - conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 48, 128 - conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 24, 256 - conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 12, 512 - conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 6, 512 - conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ - act_type=act_type, mode=mode) - conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ - act_type=act_type, mode=mode) - # 3, 512 - self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ - conv9, conv10, conv11) - - # classifier - self.classifier = nn.Sequential( - nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - #################### # Perceptual Network #################### diff --git a/codes/models/SPSR_modules/block.py b/codes/models/SPSR_modules/block.py index 07123bca..dfaa2246 100644 --- a/codes/models/SPSR_modules/block.py +++ b/codes/models/SPSR_modules/block.py @@ -29,6 +29,8 @@ def norm(norm_type, nc): layer = nn.BatchNorm2d(nc, affine=True) elif norm_type == 'instance': layer = nn.InstanceNorm2d(nc, affine=False) + elif norm_type == 'group': + layer = nn.GroupNorm(8, nc) else: raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) return layer diff --git a/codes/models/SPSR_modules/loss.py b/codes/models/SPSR_modules/loss.py index 25c2d765..10dacddf 100644 --- a/codes/models/SPSR_modules/loss.py +++ b/codes/models/SPSR_modules/loss.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -# Define GAN loss: [vanilla | lsgan | wgan-gp] +# Define GAN loss: [vanilla | lsgan] class GANLoss(nn.Module): def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): super(GANLoss, self).__init__() @@ -14,19 +14,10 @@ class GANLoss(nn.Module): self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == 'lsgan': self.loss = nn.MSELoss() - elif self.gan_type == 'wgan-gp': - - def wgan_loss(input, target): - # target is boolean - return -1 * input.mean() if target else input.mean() - - self.loss = wgan_loss else: raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) def get_target_label(self, input, target_is_real): - if self.gan_type == 'wgan-gp': - return target_is_real if target_is_real: return torch.empty_like(input).fill_(self.real_label_val) else: @@ -36,25 +27,3 @@ class GANLoss(nn.Module): target_label = self.get_target_label(input, target_is_real) loss = self.loss(input, target_label) return loss - - -class GradientPenaltyLoss(nn.Module): - def __init__(self, device=torch.device('cpu')): - super(GradientPenaltyLoss, self).__init__() - self.register_buffer('grad_outputs', torch.Tensor()) - self.grad_outputs = self.grad_outputs.to(device) - - def get_grad_outputs(self, input): - if self.grad_outputs.size() != input.size(): - self.grad_outputs.resize_(input.size()).fill_(1.0) - return self.grad_outputs - - def forward(self, interp, interp_crit): - grad_outputs = self.get_grad_outputs(interp_crit) - grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \ - grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0] - grad_interp = grad_interp.view(grad_interp.size(0), -1) - grad_interp_norm = grad_interp.norm(2, dim=1) - - loss = ((grad_interp_norm - 1)**2).mean() - return loss diff --git a/codes/models/SPSR_modules/spectral_norm.py b/codes/models/SPSR_modules/spectral_norm.py deleted file mode 100644 index 3ca2b636..00000000 --- a/codes/models/SPSR_modules/spectral_norm.py +++ /dev/null @@ -1,149 +0,0 @@ -''' -Copy from pytorch github repo -Spectral Normalization from https://arxiv.org/abs/1802.05957 -''' -import torch -from torch.nn.functional import normalize -from torch.nn.parameter import Parameter - - -class SpectralNorm(object): - def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): - self.name = name - self.dim = dim - if n_power_iterations <= 0: - raise ValueError('Expected n_power_iterations to be positive, but ' - 'got n_power_iterations={}'.format(n_power_iterations)) - self.n_power_iterations = n_power_iterations - self.eps = eps - - def compute_weight(self, module): - weight = getattr(module, self.name + '_orig') - u = getattr(module, self.name + '_u') - weight_mat = weight - if self.dim != 0: - # permute dim to front - weight_mat = weight_mat.permute(self.dim, - *[d for d in range(weight_mat.dim()) if d != self.dim]) - height = weight_mat.size(0) - weight_mat = weight_mat.reshape(height, -1) - with torch.no_grad(): - for _ in range(self.n_power_iterations): - # Spectral norm of weight equals to `u^T W v`, where `u` and `v` - # are the first left and right singular vectors. - # This power iteration produces approximations of `u` and `v`. - v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) - u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) - - sigma = torch.dot(u, torch.matmul(weight_mat, v)) - weight = weight / sigma - return weight, u - - def remove(self, module): - weight = getattr(module, self.name) - delattr(module, self.name) - delattr(module, self.name + '_u') - delattr(module, self.name + '_orig') - module.register_parameter(self.name, torch.nn.Parameter(weight)) - - def __call__(self, module, inputs): - if module.training: - weight, u = self.compute_weight(module) - setattr(module, self.name, weight) - setattr(module, self.name + '_u', u) - else: - r_g = getattr(module, self.name + '_orig').requires_grad - getattr(module, self.name).detach_().requires_grad_(r_g) - - @staticmethod - def apply(module, name, n_power_iterations, dim, eps): - fn = SpectralNorm(name, n_power_iterations, dim, eps) - weight = module._parameters[name] - height = weight.size(dim) - - u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) - delattr(module, fn.name) - module.register_parameter(fn.name + "_orig", weight) - # We still need to assign weight back as fn.name because all sorts of - # things may assume that it exists, e.g., when initializing weights. - # However, we can't directly assign as it could be an nn.Parameter and - # gets added as a parameter. Instead, we register weight.data as a - # buffer, which will cause weight to be included in the state dict - # and also supports nn.init due to shared storage. - module.register_buffer(fn.name, weight.data) - module.register_buffer(fn.name + "_u", u) - - module.register_forward_pre_hook(fn) - return fn - - -def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): - r"""Applies spectral normalization to a parameter in the given module. - - .. math:: - \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ - \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} - - Spectral normalization stabilizes the training of discriminators (critics) - in Generaive Adversarial Networks (GANs) by rescaling the weight tensor - with spectral norm :math:`\sigma` of the weight matrix calculated using - power iteration method. If the dimension of the weight tensor is greater - than 2, it is reshaped to 2D in power iteration method to get spectral - norm. This is implemented via a hook that calculates spectral norm and - rescales weight before every :meth:`~Module.forward` call. - - See `Spectral Normalization for Generative Adversarial Networks`_ . - - .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 - - Args: - module (nn.Module): containing module - name (str, optional): name of weight parameter - n_power_iterations (int, optional): number of power iterations to - calculate spectal norm - eps (float, optional): epsilon for numerical stability in - calculating norms - dim (int, optional): dimension corresponding to number of outputs, - the default is 0, except for modules that are instances of - ConvTranspose1/2/3d, when it is 1 - - Returns: - The original module with the spectal norm hook - - Example:: - - >>> m = spectral_norm(nn.Linear(20, 40)) - Linear (20 -> 40) - >>> m.weight_u.size() - torch.Size([20]) - - """ - if dim is None: - if isinstance( - module, - (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)): - dim = 1 - else: - dim = 0 - SpectralNorm.apply(module, name, n_power_iterations, dim, eps) - return module - - -def remove_spectral_norm(module, name='weight'): - r"""Removes the spectral normalization reparameterization from a module. - - Args: - module (nn.Module): containing module - name (str, optional): name of weight parameter - - Example: - >>> m = spectral_norm(nn.Linear(40, 10)) - >>> remove_spectral_norm(m) - """ - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, SpectralNorm) and hook.name == name: - hook.remove(module) - del module._forward_pre_hooks[k] - return module - - raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) diff --git a/codes/models/loss.py b/codes/models/loss.py index 6698ad2b..b3f72563 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -18,7 +18,7 @@ class CharbonnierLoss(nn.Module): return loss -# Define GAN loss: [vanilla | lsgan | wgan-gp] +# Define GAN loss: [vanilla | lsgan] class GANLoss(nn.Module): def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): super(GANLoss, self).__init__() @@ -30,19 +30,10 @@ class GANLoss(nn.Module): self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == 'lsgan': self.loss = nn.MSELoss() - elif self.gan_type == 'wgan-gp': - - def wgan_loss(input, target): - # target is boolean - return -1 * input.mean() if target else input.mean() - - self.loss = wgan_loss else: raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) def get_target_label(self, input, target_is_real): - if self.gan_type == 'wgan-gp': - return target_is_real if target_is_real: return torch.empty_like(input).fill_(self.real_label_val) else: @@ -57,29 +48,6 @@ class GANLoss(nn.Module): return loss -class GradientPenaltyLoss(nn.Module): - def __init__(self, device=torch.device('cpu')): - super(GradientPenaltyLoss, self).__init__() - self.register_buffer('grad_outputs', torch.Tensor()) - self.grad_outputs = self.grad_outputs.to(device) - - def get_grad_outputs(self, input): - if self.grad_outputs.size() != input.size(): - self.grad_outputs.resize_(input.size()).fill_(1.0) - return self.grad_outputs - - def forward(self, interp, interp_crit): - grad_outputs = self.get_grad_outputs(interp_crit) - grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, - grad_outputs=grad_outputs, create_graph=True, - retain_graph=True, only_inputs=True)[0] - grad_interp = grad_interp.view(grad_interp.size(0), -1) - grad_interp_norm = grad_interp.norm(2, dim=1) - - loss = ((grad_interp_norm - 1)**2).mean() - return loss - - # Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL # Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py # In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs). diff --git a/codes/train.py b/codes/train.py index 4162cb38..89d70ac7 100644 --- a/codes/train.py +++ b/codes/train.py @@ -215,7 +215,7 @@ def main(): logger.info(message) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: - if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation + if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0: # image restoration validation model.force_restore_swapout() val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] # does not support multi-GPU validation