Add model swapout
Model swapout is a feature where, at specified intervals, a randomly chosen previously-saved D or G checkpoint is swapped in for the model currently being trained. After a short period, the original model is swapped back in. This is intended to increase training diversity.
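To make the mechanism concrete, here is a minimal, self-contained sketch of the swapout schedule, written independently of this repo. The Swapout class and the save_fn/load_fn callbacks are illustrative stand-ins for what the commit actually does inside SRGANModel via save_network()/load_network(); it assumes checkpoints are saved as *_G.pth / *_D.pth files in a models directory.

import glob
import os
import random

class Swapout:
    """Every `freq` steps, swap a random previous checkpoint in for the live
    model; after `duration` further steps, restore the stashed live model."""

    def __init__(self, model_dir, suffix, freq, duration, save_fn, load_fn):
        self.model_dir = model_dir  # directory holding *_<suffix>.pth checkpoints
        self.suffix = suffix        # 'G' or 'D'
        self.freq = freq            # swap-in interval in steps (0 disables)
        self.duration = duration    # how long the swapped model stays in
        self.save_fn = save_fn      # save_fn() -> path where live weights were stashed
        self.load_fn = load_fn      # load_fn(path) loads weights into the live model
        self.remaining = 0
        self.stashed = None

    def pick_rand_prev_model(self):
        # Require at least two saved checkpoints so there is something
        # genuinely "previous" to pick.
        candidates = glob.glob(os.path.join(self.model_dir, "*_%s.pth" % self.suffix))
        return random.choice(candidates) if len(candidates) > 1 else None

    def step(self, step):
        if self.remaining > 0:
            self.remaining -= 1
            if self.remaining == 0:
                self.load_fn(self.stashed)  # swap the real model back in
                self.stashed = None
        elif self.freq != 0 and step % self.freq == 0:
            swapped = self.pick_rand_prev_model()
            if swapped is not None:
                self.stashed = self.save_fn()  # stash live weights before overwriting
                self.load_fn(swapped)
                self.remaining = self.duration

One detail worth noting: the config below gives G and D different prime swap frequencies (113 and 223), so the two swap windows coincide only every 113 * 223 = 25,199 steps.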
parent c336d32fd3
commit c8ab89d243
@@ -9,6 +9,8 @@ from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
 import torch.nn.functional as F
+import glob
+import random
 
 import torchvision.utils as utils
 import os

@@ -32,7 +34,10 @@ class SRGANModel(BaseModel):
 
         if 'network_C' in opt.keys():
             self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
             # The corruptor net is fixed. Lock 'er down.
+            self.netC.eval()
+            for p in self.netC.parameters():
+                p.requires_grad = False
         else:
             self.netC = None
 
@@ -147,6 +152,13 @@ class SRGANModel(BaseModel):
         self.log_dict = OrderedDict()
 
+        # Swapout params
+        self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0
+        self.swapout_G_duration = 0
+        self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0
+        self.swapout_D_duration = 0
+        self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
+
         self.print_network()  # print network
         self.load()  # load G and D if needed

@@ -174,6 +186,9 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
+        self.swapout_D(step)
+        self.swapout_G(step)
+
         # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             for p in self.netG.parameters():

@@ -248,7 +263,12 @@ class SRGANModel(BaseModel):
             noise = torch.randn_like(var_ref[0]) * noise_theta
             noise.to(self.device)
             self.optimizer_D.zero_grad()
-            for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+                # Re-compute generator outputs (post-update).
+                with torch.no_grad():
+                    fake_H = self.netG(var_L)
+                    fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
+
                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = (var_ref[0] + noise,) + var_ref[1:]
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]

@@ -345,6 +365,45 @@ class SRGANModel(BaseModel):
         lo_skip = F.interpolate(truth_img, scale_factor=.25)
         return med_skip, lo_skip
 
+    def pick_rand_prev_model(self, model_suffix):
+        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
+        if len(previous_models) <= 1:
+            return None
+        # Just a note: this intentionally includes the swap model in the list of possibilities.
+        return previous_models[random.randint(0, len(previous_models)-1)]
+
+    def swapout_D(self, step):
+        if self.swapout_D_duration > 0:
+            self.swapout_D_duration -= 1
+            if self.swapout_D_duration == 0:
+                # Swap back.
+                print("Swapping back to current D model: %s" % (self.stashed_D,))
+                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
+                self.stashed_D = None
+        elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
+            swapped_model = self.pick_rand_prev_model('D')
+            if swapped_model is not None:
+                print("Swapping to previous D model: %s" % (swapped_model,))
+                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
+                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
+                self.swapout_D_duration = self.swapout_duration
+
+    def swapout_G(self, step):
+        if self.swapout_G_duration > 0:
+            self.swapout_G_duration -= 1
+            if self.swapout_G_duration == 0:
+                # Swap back.
+                print("Swapping back to current G model: %s" % (self.stashed_G,))
+                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
+                self.stashed_G = None
+        elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
+            swapped_model = self.pick_rand_prev_model('G')
+            if swapped_model is not None:
+                print("Swapping to previous G model: %s" % (swapped_model,))
+                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
+                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
+                self.swapout_G_duration = self.swapout_duration
+
     def test(self):
         self.netG.eval()
         with torch.no_grad():

@@ -81,6 +81,7 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        return save_path
 
     def load_network(self, load_path, network, strict=True):
         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):

@@ -37,7 +37,7 @@ network_G:
   nb_denoiser: 20
   nb_upsampler: 0
   upscale_applications: 0
-  inject_noise: True
+  inject_noise: False
 
 network_D:
   which_model_D: discriminator_vgg_128

@@ -48,7 +48,7 @@ network_D:
 path:
   pretrain_model_G: ~
   pretrain_model_D: ~
-  resume_state: ../experiments/train_vix_corrupt/training_state/31000.state
+  resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss

@@ -81,6 +81,9 @@ train:
   gan_type: ragan # gan | ragan
   gan_weight: .1
   mega_batch_factor: 1
+  swapout_G_freq: 113
+  swapout_D_freq: 223
+  swapout_duration: 40
 
   D_update_ratio: 1
   D_init_iters: -1

@@ -91,4 +94,4 @@ train:
 #### logger
 logger:
   print_freq: 50
-  save_checkpoint_freq: !!float 5e2
+  save_checkpoint_freq: 500