Add model swapout

Model swapout is a feature where, at specified intervals, a randomly
chosen previously-saved D or G checkpoint is swapped in, in place of
the model currently being trained. After a short period, the original
model is swapped back in. This is intended to increase training
diversity.
James Betker 2020-05-13 16:53:38 -06:00
parent c336d32fd3
commit c8ab89d243
3 changed files with 67 additions and 4 deletions
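
For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of the swap/restore cycle described above. Names such as SwapoutSketch, maybe_swap, save_fn and load_fn are illustrative assumptions; the actual logic is in SRGANModel.swapout_D / swapout_G and BaseModel.save_network below.

import glob
import os
import random

class SwapoutSketch:
    # Hypothetical stand-in for the per-network (G or D) bookkeeping SRGANModel does.
    def __init__(self, model_dir, freq, duration):
        self.model_dir = model_dir   # directory holding *_G.pth / *_D.pth checkpoints
        self.freq = freq             # consider a swap every `freq` steps
        self.duration = duration     # train against the swapped-in model this many steps
        self.remaining = 0           # steps left in the current swap window
        self.stashed = None          # path of the stashed "current" checkpoint

    def maybe_swap(self, step, suffix, save_fn, load_fn):
        # save_fn(name) -> path of a freshly saved checkpoint; load_fn(path) restores one.
        if self.remaining > 0:
            self.remaining -= 1
            if self.remaining == 0:
                load_fn(self.stashed)   # swap the original model back in
                self.stashed = None
        elif self.freq and step % self.freq == 0:
            candidates = glob.glob(os.path.join(self.model_dir, "*_%s.pth" % suffix))
            if len(candidates) > 1:     # need at least one checkpoint besides the current one
                self.stashed = save_fn('swap_model')
                load_fn(random.choice(candidates))
                self.remaining = self.duration

The key enabler in this commit is that BaseModel.save_network now returns the path it wrote, so the current weights can be stashed right before a swap and restored once the window ends.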


@@ -9,6 +9,8 @@ from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
 import torch.nn.functional as F
+import glob
+import random
 import torchvision.utils as utils
 import os
@@ -32,7 +34,10 @@ class SRGANModel(BaseModel):
         if 'network_C' in opt.keys():
             self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
+            # The corruptor net is fixed. Lock 'her down.
             self.netC.eval()
+            for p in self.netC.parameters():
+                p.requires_grad = False
         else:
             self.netC = None
@@ -147,6 +152,13 @@ class SRGANModel(BaseModel):
         self.log_dict = OrderedDict()
+
+        # Swapout params
+        self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0
+        self.swapout_G_duration = 0
+        self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0
+        self.swapout_D_duration = 0
+        self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
         self.print_network()  # print network
         self.load()  # load G and D if needed
@@ -174,6 +186,9 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
+        self.swapout_D(step)
+        self.swapout_G(step)
+
         # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             for p in self.netG.parameters():
@@ -248,7 +263,12 @@ class SRGANModel(BaseModel):
         noise = torch.randn_like(var_ref[0]) * noise_theta
         noise.to(self.device)
         self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
+        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+            # Re-compute generator outputs (post-update).
+            with torch.no_grad():
+                fake_H = self.netG(var_L)
+                fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
+
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = (var_ref[0] + noise,) + var_ref[1:]
             fake_H = (fake_H[0] + noise,) + fake_H[1:]
@@ -345,6 +365,45 @@ class SRGANModel(BaseModel):
         lo_skip = F.interpolate(truth_img, scale_factor=.25)
         return med_skip, lo_skip
+
+    def pick_rand_prev_model(self, model_suffix):
+        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
+        if len(previous_models) <= 1:
+            return None
+        # Just a note: this intentionally includes the swap model in the list of possibilities.
+        return previous_models[random.randint(0, len(previous_models)-1)]
+
+    def swapout_D(self, step):
+        if self.swapout_D_duration > 0:
+            self.swapout_D_duration -= 1
+            if self.swapout_D_duration == 0:
+                # Swap back.
+                print("Swapping back to current D model: %s" % (self.stashed_D,))
+                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
+                self.stashed_D = None
+        elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
+            swapped_model = self.pick_rand_prev_model('D')
+            if swapped_model is not None:
+                print("Swapping to previous D model: %s" % (swapped_model,))
+                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
+                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
+                self.swapout_D_duration = self.swapout_duration
+
+    def swapout_G(self, step):
+        if self.swapout_G_duration > 0:
+            self.swapout_G_duration -= 1
+            if self.swapout_G_duration == 0:
+                # Swap back.
+                print("Swapping back to current G model: %s" % (self.stashed_G,))
+                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
+                self.stashed_G = None
+        elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
+            swapped_model = self.pick_rand_prev_model('G')
+            if swapped_model is not None:
+                print("Swapping to previous G model: %s" % (swapped_model,))
+                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
+                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
+                self.swapout_G_duration = self.swapout_duration
 
     def test(self):
         self.netG.eval()
         with torch.no_grad():


@@ -81,6 +81,7 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        return save_path
 
     def load_network(self, load_path, network, strict=True):
         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):


@@ -37,7 +37,7 @@ network_G:
   nb_denoiser: 20
   nb_upsampler: 0
   upscale_applications: 0
-  inject_noise: True
+  inject_noise: False
 
 network_D:
   which_model_D: discriminator_vgg_128
@@ -48,7 +48,7 @@ network_D:
 path:
   pretrain_model_G: ~
   pretrain_model_D: ~
-  resume_state: ../experiments/train_vix_corrupt/training_state/31000.state
+  resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss
@@ -81,6 +81,9 @@ train:
   gan_type: ragan  # gan | ragan
   gan_weight: .1
   mega_batch_factor: 1
+  swapout_G_freq: 113
+  swapout_D_freq: 223
+  swapout_duration: 40
   D_update_ratio: 1
   D_init_iters: -1
@@ -91,4 +94,4 @@ train:
 #### logger
 logger:
   print_freq: 50
-  save_checkpoint_freq: !!float 5e2
+  save_checkpoint_freq: 500
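
A rough usage note, not part of the commit: assuming the step counter passed to swapout_G/swapout_D increases by one per iteration (as in optimize_parameters above), the values configured here produce swap windows like those sketched below. The swap_windows helper is hypothetical, purely for illustration.

swapout_G_freq, swapout_D_freq, swapout_duration = 113, 223, 40

def swap_windows(freq, duration, max_step=1000):
    # A swap can trigger whenever step % freq == 0 (and at least two saved
    # checkpoints exist); the swapped-in model is restored `duration` steps later.
    return [(s, s + duration) for s in range(freq, max_step, freq)]

print(swap_windows(swapout_G_freq, swapout_duration))  # G: (113, 153), (226, 266), (339, 379), ...
print(swap_windows(swapout_D_freq, swapout_duration))  # D: (223, 263), (446, 486), (669, 709), ...

Because 113 and 223 are coprime, G and D swaps never trigger on the same step until step 25199 (113 * 223), though their 40-step windows can still overlap occasionally.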