Add model swapout
Model swapout is a feature where, at specified intervals, a random previously saved D or G checkpoint is swapped in place of the one being trained. After a short period, the original model is swapped back in. This is intended to increase training diversity.
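In outline, each swap is a small state machine driven by the training step counter. A minimal sketch of the per-step logic, with illustrative attribute names (save_current and load stand in for the BaseModel save_network/load_network helpers used in the diff below):

    def swapout(self, step):
        if self.swap_remaining > 0:
            # A previous checkpoint is currently swapped in; count down.
            self.swap_remaining -= 1
            if self.swap_remaining == 0:
                # Duration elapsed: restore the stashed current weights.
                self.load(self.stashed)
                self.stashed = None
        elif self.swap_freq != 0 and step % self.swap_freq == 0:
            # Time to swap: stash the live weights, load a random old checkpoint.
            old_model = self.pick_rand_prev_model()  # random saved checkpoint, or None
            if old_model is not None:
                self.stashed = self.save_current()  # returns the stash path
                self.load(old_model)
                self.swap_remaining = self.swap_duration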
parent c336d32fd3
commit c8ab89d243
@@ -9,6 +9,8 @@ from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
 import torch.nn.functional as F
+import glob
+import random
 
 import torchvision.utils as utils
 import os
@@ -32,7 +34,10 @@ class SRGANModel(BaseModel):
 
         if 'network_C' in opt.keys():
             self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
+            # The corruptor net is fixed. Lock 'her down.
             self.netC.eval()
+            for p in self.netC.parameters():
+                p.requires_grad = False
         else:
             self.netC = None
 
@@ -147,6 +152,13 @@ class SRGANModel(BaseModel):
 
             self.log_dict = OrderedDict()
 
+            # Swapout params
+            self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0
+            self.swapout_G_duration = 0
+            self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0
+            self.swapout_D_duration = 0
+            self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
+
         self.print_network()  # print network
         self.load()  # load G and D if needed
 
@@ -174,6 +186,9 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
+        self.swapout_D(step)
+        self.swapout_G(step)
+
         # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             for p in self.netG.parameters():
@@ -248,7 +263,12 @@ class SRGANModel(BaseModel):
         noise = torch.randn_like(var_ref[0]) * noise_theta
         noise.to(self.device)
         self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
+        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+            # Re-compute generator outputs (post-update).
+            with torch.no_grad():
+                fake_H = self.netG(var_L)
+                fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
+
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = (var_ref[0] + noise,) + var_ref[1:]
             fake_H = (fake_H[0] + noise,) + fake_H[1:]
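Note the reworked discriminator loop above: rather than reusing the generator outputs cached in self.fake_GenOut (which were computed before the G step), the D update now recomputes fake_H under torch.no_grad(), so the discriminator trains against the generator's post-update outputs.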
@@ -345,6 +365,45 @@ class SRGANModel(BaseModel):
         lo_skip = F.interpolate(truth_img, scale_factor=.25)
         return med_skip, lo_skip
 
+    def pick_rand_prev_model(self, model_suffix):
+        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
+        if len(previous_models) <= 1:
+            return None
+        # Just a note: this intentionally includes the swap model in the list of possibilities.
+        return previous_models[random.randint(0, len(previous_models)-1)]
+
+    def swapout_D(self, step):
+        if self.swapout_D_duration > 0:
+            self.swapout_D_duration -= 1
+            if self.swapout_D_duration == 0:
+                # Swap back.
+                print("Swapping back to current D model: %s" % (self.stashed_D,))
+                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
+                self.stashed_D = None
+        elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
+            swapped_model = self.pick_rand_prev_model('D')
+            if swapped_model is not None:
+                print("Swapping to previous D model: %s" % (swapped_model,))
+                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
+                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
+                self.swapout_D_duration = self.swapout_duration
+
+    def swapout_G(self, step):
+        if self.swapout_G_duration > 0:
+            self.swapout_G_duration -= 1
+            if self.swapout_G_duration == 0:
+                # Swap back.
+                print("Swapping back to current G model: %s" % (self.stashed_G,))
+                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
+                self.stashed_G = None
+        elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
+            swapped_model = self.pick_rand_prev_model('G')
+            if swapped_model is not None:
+                print("Swapping to previous G model: %s" % (swapped_model,))
+                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
+                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
+                self.swapout_G_duration = self.swapout_duration
+
     def test(self):
         self.netG.eval()
         with torch.no_grad():
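Two details worth noting in the implementation above: save_network(..., 'swap_model') both stashes the live weights and returns the path they were written to, which is what lets the countdown branch restore the original model later; and, per its comment, pick_rand_prev_model intentionally leaves the stashed swap_model checkpoint in the candidate pool, so a swap can occasionally land back on the current weights.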
@@ -81,6 +81,7 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        return save_path
 
     def load_network(self, load_path, network, strict=True):
         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
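This one-line change in base_model makes save_network report where it wrote the checkpoint; swapout_D and swapout_G depend on the returned path to stash the current model before loading a previous one.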
@@ -37,7 +37,7 @@ network_G:
   nb_denoiser: 20
   nb_upsampler: 0
   upscale_applications: 0
-  inject_noise: True
+  inject_noise: False
 
 network_D:
   which_model_D: discriminator_vgg_128
@@ -48,7 +48,7 @@ network_D:
 path:
   pretrain_model_G: ~
   pretrain_model_D: ~
-  resume_state: ../experiments/train_vix_corrupt/training_state/31000.state
+  resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss
@@ -81,6 +81,9 @@ train:
   gan_type: ragan  # gan | ragan
   gan_weight: .1
   mega_batch_factor: 1
+  swapout_G_freq: 113
+  swapout_D_freq: 223
+  swapout_duration: 40
 
   D_update_ratio: 1
   D_init_iters: -1
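The sample cadences differ for G and D (113 and 223 steps, both prime, presumably so the two swap windows rarely overlap); each swap then lasts swapout_duration = 40 steps before the stashed model is restored.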
@@ -91,4 +94,4 @@ train:
 #### logger
 logger:
   print_freq: 50
-  save_checkpoint_freq: !!float 5e2
+  save_checkpoint_freq: 500