From c8ab89d2439420d669ada15244eab329c43aa741 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 13 May 2020 16:53:38 -0600
Subject: [PATCH] Add model swapout

Model swapout is a feature where, at specified intervals, a randomly chosen
previous D or G checkpoint is swapped in for the model currently being
trained. After a short number of steps, the original model is swapped back
in. This is intended to increase training diversity.
---
 codes/models/SRGAN_model.py               | 61 ++++++++++++++++++++++-
 codes/models/base_model.py                |  1 +
 codes/options/train/train_vix_corrupt.yml |  9 ++--
 3 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 1b715b05..d1283232 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -9,6 +9,8 @@ from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
 import torch.nn.functional as F
+import glob
+import random
 import torchvision.utils as utils
 import os
 
@@ -32,7 +34,10 @@ class SRGANModel(BaseModel):
 
         if 'network_C' in opt.keys():
             self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
+            # The corruptor net is fixed. Lock 'her down.
             self.netC.eval()
+            for p in self.netC.parameters():
+                p.requires_grad = False
         else:
             self.netC = None
 
@@ -147,6 +152,13 @@
 
         self.log_dict = OrderedDict()
 
+        # Swapout params
+        self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0
+        self.swapout_G_duration = 0
+        self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0
+        self.swapout_D_duration = 0
+        self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
+
         self.print_network()  # print network
         self.load()  # load G and D if needed
 
@@ -174,6 +186,9 @@
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
+        self.swapout_D(step)
+        self.swapout_G(step)
+
         # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             for p in self.netG.parameters():
@@ -248,7 +263,12 @@
             noise = torch.randn_like(var_ref[0]) * noise_theta
             noise.to(self.device)
             self.optimizer_D.zero_grad()
-            for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+                # Re-compute generator outputs (post-update).
+                with torch.no_grad():
+                    fake_H = self.netG(var_L)
+                    fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
+
                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = (var_ref[0] + noise,) + var_ref[1:]
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
@@ -345,6 +365,45 @@
         lo_skip = F.interpolate(truth_img, scale_factor=.25)
         return med_skip, lo_skip
 
+    def pick_rand_prev_model(self, model_suffix):
+        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
+        if len(previous_models) <= 1:
+            return None
+        # Just a note: this intentionally includes the swap model in the list of possibilities.
+        return previous_models[random.randint(0, len(previous_models)-1)]
+
+    def swapout_D(self, step):
+        if self.swapout_D_duration > 0:
+            self.swapout_D_duration -= 1
+            if self.swapout_D_duration == 0:
+                # Swap back.
+                print("Swapping back to current D model: %s" % (self.stashed_D,))
+                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
+                self.stashed_D = None
+        elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
+            swapped_model = self.pick_rand_prev_model('D')
+            if swapped_model is not None:
+                print("Swapping to previous D model: %s" % (swapped_model,))
+                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
+                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
+                self.swapout_D_duration = self.swapout_duration
+
+    def swapout_G(self, step):
+        if self.swapout_G_duration > 0:
+            self.swapout_G_duration -= 1
+            if self.swapout_G_duration == 0:
+                # Swap back.
+                print("Swapping back to current G model: %s" % (self.stashed_G,))
+                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
+                self.stashed_G = None
+        elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
+            swapped_model = self.pick_rand_prev_model('G')
+            if swapped_model is not None:
+                print("Swapping to previous G model: %s" % (swapped_model,))
+                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
+                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
+                self.swapout_G_duration = self.swapout_duration
+
     def test(self):
         self.netG.eval()
         with torch.no_grad():
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index 7d9dfdb1..f5013921 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -81,6 +81,7 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        return save_path
 
     def load_network(self, load_path, network, strict=True):
         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
diff --git a/codes/options/train/train_vix_corrupt.yml b/codes/options/train/train_vix_corrupt.yml
index 8ae35573..3c65379c 100644
--- a/codes/options/train/train_vix_corrupt.yml
+++ b/codes/options/train/train_vix_corrupt.yml
@@ -37,7 +37,7 @@ network_G:
   nb_denoiser: 20
   nb_upsampler: 0
   upscale_applications: 0
-  inject_noise: True
+  inject_noise: False
 
 network_D:
   which_model_D: discriminator_vgg_128
@@ -48,7 +48,7 @@ network_D:
 path:
   pretrain_model_G: ~
   pretrain_model_D: ~
-  resume_state: ../experiments/train_vix_corrupt/training_state/31000.state
+  resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss
@@ -81,6 +81,9 @@ train:
   gan_type: ragan  # gan | ragan
   gan_weight: .1
   mega_batch_factor: 1
+  swapout_G_freq: 113
+  swapout_D_freq: 223
+  swapout_duration: 40
 
   D_update_ratio: 1
   D_init_iters: -1
@@ -91,4 +94,4 @@ train:
 #### logger
 logger:
   print_freq: 50
-  save_checkpoint_freq: !!float 5e2
+  save_checkpoint_freq: 500
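
For reference, the swap-in/swap-out cycle implemented by swapout_D/swapout_G above can be sketched as a standalone routine. This is illustrative only: the SwapoutController class and its names do not exist in this repository; freq and duration simply mirror the swapout_*_freq and swapout_duration options added to train_vix_corrupt.yml.

    import random


    class SwapoutController:
        """Illustrative sketch (not part of the patch): every `freq` steps a
        randomly chosen earlier snapshot is swapped in for the live weights;
        `duration` steps later the stashed live weights are swapped back."""

        def __init__(self, freq, duration):
            self.freq = freq          # mirrors swapout_D_freq / swapout_G_freq
            self.duration = duration  # mirrors swapout_duration
            self.remaining = 0        # steps left until the stash is restored
            self.stashed = None       # live weights held aside during a swap

        def maybe_swap(self, step, active, snapshots):
            """Return the weights that should be trained at this step."""
            if self.remaining > 0:
                self.remaining -= 1
                if self.remaining == 0:
                    # Swap period over: restore the stashed live weights.
                    active, self.stashed = self.stashed, None
                return active
            if self.freq and step % self.freq == 0 and snapshots:
                # Stash the live weights and train on a random past snapshot.
                self.stashed = active
                self.remaining = self.duration
                return random.choice(snapshots)
            return active


    # Toy usage with labels standing in for model weights.
    if __name__ == "__main__":
        ctrl = SwapoutController(freq=5, duration=2)
        active = "live"
        for step in range(1, 13):
            active = ctrl.maybe_swap(step, active, ["ckpt_100.pth", "ckpt_200.pth"])
            print(step, active)

As in the patch itself, any updates applied to the swapped-in snapshot are discarded when the stashed weights are restored; only their indirect effect on the other network (e.g. G trained against a swapped D) persists.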