Some work on extensible trainer

This commit is contained in:
James Betker 2020-08-18 08:49:32 -06:00
parent 0c98c61f4a
commit 74cdaa2226
4 changed files with 100 additions and 45 deletions

View File

@ -31,16 +31,16 @@ class ExtensibleTrainer(BaseModel):
train_opt = opt['train'] train_opt = opt['train']
self.mega_batch_factor = 1 self.mega_batch_factor = 1
self.netG = {} self.netsG = {}
self.netD = {} self.netsD = {}
self.networks = [] self.networks = []
for name, net in opt['networks'].items(): for name, net in opt['networks'].items():
if net['type'] == 'generator': if net['type'] == 'generator':
new_net = networks.define_G(net) new_net = networks.define_G(net)
self.netG[name] = new_net self.netsG[name] = new_net
elif net['type'] == 'discriminator': elif net['type'] == 'discriminator':
new_net = networks.define_D(net) new_net = networks.define_D(net)
self.netD[name] = new_net self.netsD[name] = new_net
else: else:
raise NotImplementedError("Can only handle generators and discriminators") raise NotImplementedError("Can only handle generators and discriminators")
self.networks.append(new_net) self.networks.append(new_net)
@ -74,7 +74,7 @@ class ExtensibleTrainer(BaseModel):
# Backpush the wrapped networks into the network dicts.. # Backpush the wrapped networks into the network dicts..
found = 0 found = 0
for dnet in dnets: for dnet in dnets:
for net_dict in [self.netD, self.netG]: for net_dict in [self.netsD, self.netsG]:
for k, v in net_dict.items(): for k, v in net_dict.items():
if v == dnet: if v == dnet:
net_dict[k] = dnet net_dict[k] = dnet
@ -84,7 +84,7 @@ class ExtensibleTrainer(BaseModel):
# Initialize the training steps # Initialize the training steps
self.steps = [] self.steps = []
for step in opt['steps']: for step in opt['steps']:
step = create_step(step, self.netG, self.netD) step = create_step(step, self.netsG, self.netsD)
self.steps.append(step) self.steps.append(step)
self.optimizers.extend(step.get_optimizers()) self.optimizers.extend(step.get_optimizers())
@ -119,7 +119,7 @@ class ExtensibleTrainer(BaseModel):
nets_to_train = s.get_networks_trained() nets_to_train = s.get_networks_trained()
for name, net in self.networks.items(): for name, net in self.networks.items():
net_enabled = name in nets_to_train net_enabled = name in nets_to_train
for p in self.netG.parameters(): for p in self.netsG.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool: if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = net_enabled p.requires_grad = net_enabled
else: else:
@ -135,7 +135,7 @@ class ExtensibleTrainer(BaseModel):
# G # G
for p in self.netD.parameters(): for p in self.netsD.parameters():
p.requires_grad = False p.requires_grad = False
if self.spsr_enabled: if self.spsr_enabled:
for p in self.netD_grad.parameters(): for p in self.netD_grad.parameters():
@ -147,15 +147,15 @@ class ExtensibleTrainer(BaseModel):
# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
if step % self.D_update_ratio == 0 and step >= self.D_init_iters: if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters: if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters:
for k, v in self.netG.named_parameters(): for k, v in self.netsG.named_parameters():
if v.dtype != torch.int64 and v.dtype != torch.bool: if v.dtype != torch.int64 and v.dtype != torch.bool:
v.requires_grad = '_branch_pretrain' in k v.requires_grad = '_branch_pretrain' in k
else: else:
for p in self.netG.parameters(): for p in self.netsG.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool: if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = True p.requires_grad = True
else: else:
for p in self.netG.parameters(): for p in self.netsG.parameters():
p.requires_grad = False p.requires_grad = False
# Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta. # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta.
@ -179,17 +179,17 @@ class ExtensibleTrainer(BaseModel):
if self.spsr_enabled: if self.spsr_enabled:
using_gan_img = False using_gan_img = False
# SPSR models have outputs from three different branches. # SPSR models have outputs from three different branches.
fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L) fake_H_branch, fake_GenOut, grad_LR = self.netsG(var_L)
fea_GenOut = fake_GenOut fea_GenOut = fake_GenOut
self.spsr_grad_GenOut.append(fake_H_branch) self.spsr_grad_GenOut.append(fake_H_branch)
# Get image gradients for later use. # Get image gradients for later use.
fake_H_grad = self.get_grad_nopadding(fake_GenOut) fake_H_grad = self.get_grad_nopadding(fake_GenOut)
else: else:
if random.random() > self.gan_lq_img_use_prob: if random.random() > self.gan_lq_img_use_prob:
fea_GenOut, fake_GenOut = self.netG(var_L) fea_GenOut, fake_GenOut = self.netsG(var_L)
using_gan_img = False using_gan_img = False
else: else:
fea_GenOut, fake_GenOut = self.netG(var_LGAN) fea_GenOut, fake_GenOut = self.netsG(var_LGAN)
using_gan_img = True using_gan_img = True
if _profile: if _profile:
@ -262,13 +262,13 @@ class ExtensibleTrainer(BaseModel):
if self.l_gan_w > 0: if self.l_gan_w > 0:
if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
if self.opt['train']['gan_type'] == 'crossgan': if self.opt['train']['gan_type'] == 'crossgan':
pred_g_fake = self.netD(fake_GenOut, var_L) pred_g_fake = self.netsD(fake_GenOut, var_L)
else: else:
pred_g_fake = self.netD(fake_GenOut) pred_g_fake = self.netsD(fake_GenOut)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(var_ref).detach() pred_d_real = self.netsD(var_ref).detach()
pred_g_fake = self.netD(fake_GenOut) pred_g_fake = self.netsD(fake_GenOut)
l_g_gan = self.l_gan_w * ( l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@ -277,9 +277,9 @@ class ExtensibleTrainer(BaseModel):
if self.spsr_enabled and self.cri_grad_gan: if self.spsr_enabled and self.cri_grad_gan:
if self.opt['train']['gan_type'] == 'crossgan': if self.opt['train']['gan_type'] == 'crossgan':
pred_g_fake_grad = self.netD(fake_H_grad, var_L) pred_g_fake_grad = self.netsD(fake_H_grad, var_L)
else: else:
pred_g_fake_grad = self.netD(fake_H_grad) pred_g_fake_grad = self.netsD(fake_H_grad)
pred_g_fake_grad_branch = self.netD_grad(fake_H_branch) pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
@ -313,7 +313,7 @@ class ExtensibleTrainer(BaseModel):
# D # D
if self.l_gan_w > 0 and step >= self.G_warmup: if self.l_gan_w > 0 and step >= self.G_warmup:
for p in self.netD.parameters(): for p in self.netsD.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool: if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = True p.requires_grad = True
@ -328,9 +328,9 @@ class ExtensibleTrainer(BaseModel):
# Re-compute generator outputs with the GAN inputs. # Re-compute generator outputs with the GAN inputs.
with torch.no_grad(): with torch.no_grad():
if self.spsr_enabled: if self.spsr_enabled:
_, fake_H, _ = self.netG(var_LGAN) _, fake_H, _ = self.netsG(var_LGAN)
else: else:
_, fake_H = self.netG(var_LGAN) _, fake_H = self.netsG(var_LGAN)
fake_H = fake_H.detach() fake_H = fake_H.detach()
if _profile: if _profile:
@ -346,26 +346,26 @@ class ExtensibleTrainer(BaseModel):
if self.opt['train']['gan_type'] == 'pixgan_fea': if self.opt['train']['gan_type'] == 'pixgan_fea':
# Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
disc_fea_scale = .1 disc_fea_scale = .1
_, fea_real = self.netD(var_ref, output_feature_vector=True) _, fea_real = self.netsD(var_ref, output_feature_vector=True)
actual_fea = self.netF(var_ref) actual_fea = self.netF(var_ref)
l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor
_, fea_fake = self.netD(fake_H, output_feature_vector=True) _, fea_fake = self.netsD(fake_H, output_feature_vector=True)
actual_fea = self.netF(fake_H) actual_fea = self.netF(fake_H)
l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor
if self.opt['train']['gan_type'] == 'crossgan': if self.opt['train']['gan_type'] == 'crossgan':
# need to forward and backward separately, since batch norm statistics differ # need to forward and backward separately, since batch norm statistics differ
# real # real
pred_d_real = self.netD(var_ref, var_L) pred_d_real = self.netsD(var_ref, var_L)
l_d_real = self.cri_gan(pred_d_real, True) l_d_real = self.cri_gan(pred_d_real, True)
l_d_real_log = l_d_real l_d_real_log = l_d_real
# fake # fake
pred_d_fake = self.netD(fake_H, var_L) pred_d_fake = self.netsD(fake_H, var_L)
l_d_fake = self.cri_gan(pred_d_fake, False) l_d_fake = self.cri_gan(pred_d_fake, False)
l_d_fake_log = l_d_fake l_d_fake_log = l_d_fake
# mismatched # mismatched
mismatched_L = torch.roll(var_L, shifts=1, dims=0) mismatched_L = torch.roll(var_L, shifts=1, dims=0)
pred_d_real_mismatched = self.netD(var_ref, mismatched_L) pred_d_real_mismatched = self.netsD(var_ref, mismatched_L)
pred_d_fake_mismatched = self.netD(fake_H, mismatched_L) pred_d_fake_mismatched = self.netsD(fake_H, mismatched_L)
l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2 l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2
l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3 l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3
@ -374,11 +374,11 @@ class ExtensibleTrainer(BaseModel):
l_d_total_scaled.backward() l_d_total_scaled.backward()
elif self.opt['train']['gan_type'] == 'gan': elif self.opt['train']['gan_type'] == 'gan':
# real # real
pred_d_real = self.netD(var_ref) pred_d_real = self.netsD(var_ref)
l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor
# fake # fake
pred_d_fake = self.netD(fake_H) pred_d_fake = self.netsD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor
@ -386,7 +386,7 @@ class ExtensibleTrainer(BaseModel):
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
l_d_total_scaled.backward() l_d_total_scaled.backward()
elif 'pixgan' in self.opt['train']['gan_type']: elif 'pixgan' in self.opt['train']['gan_type']:
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() pixdisc_channels, pixdisc_output_reduction = self.netsD.module.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
b, _, w, h = var_ref.shape b, _, w, h = var_ref.shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
@ -424,12 +424,12 @@ class ExtensibleTrainer(BaseModel):
fake = fake.view(-1, 1) fake = fake.view(-1, 1)
# real # real
pred_d_real = self.netD(var_ref) pred_d_real = self.netsD(var_ref)
l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor
l_d_real += l_d_fea_real l_d_real += l_d_fea_real
# fake # fake
pred_d_fake = self.netD(fake_H) pred_d_fake = self.netsD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor
l_d_fake += l_d_fea_fake l_d_fake += l_d_fea_fake
@ -445,8 +445,8 @@ class ExtensibleTrainer(BaseModel):
pdf = pdf / torch.max(pdf) pdf = pdf / torch.max(pdf)
fake_disc_images.append(pdf.view(disc_output_shape)) fake_disc_images.append(pdf.view(disc_output_shape))
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
pred_d_fake = self.netD(fake_H) pred_d_fake = self.netsD(fake_H)
pred_d_real = self.netD(var_ref) pred_d_real = self.netsD(var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
l_d_real_log = l_d_real l_d_real_log = l_d_real
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
@ -597,19 +597,19 @@ class ExtensibleTrainer(BaseModel):
return self.cri_fea(fake_fea, real_fea).item() return self.cri_fea(fake_fea, real_fea).item()
def test(self): def test(self):
self.netG.eval() self.netsG.eval()
with torch.no_grad(): with torch.no_grad():
if self.spsr_enabled: if self.spsr_enabled:
self.fake_H_branch = [] self.fake_H_branch = []
self.fake_GenOut = [] self.fake_GenOut = []
self.grad_LR = [] self.grad_LR = []
fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0]) fake_H_branch, fake_GenOut, grad_LR = self.netsG(self.var_L[0])
self.fake_H_branch.append(fake_H_branch) self.fake_H_branch.append(fake_H_branch)
self.fake_GenOut.append(fake_GenOut) self.fake_GenOut.append(fake_GenOut)
self.grad_LR.append(grad_LR) self.grad_LR.append(grad_LR)
else: else:
self.fake_GenOut = [self.netG(self.var_L[0])] self.fake_GenOut = [self.netsG(self.var_L[0])]
self.netG.train() self.netsG.train()
# Fetches a summary of the log. # Fetches a summary of the log.
def get_current_log(self, step): def get_current_log(self, step):
@ -620,10 +620,10 @@ class ExtensibleTrainer(BaseModel):
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
# Some generators can do their own metric logging. # Some generators can do their own metric logging.
if hasattr(self.netG.module, "get_debug_values"): if hasattr(self.netsG.module, "get_debug_values"):
return_log.update(self.netG.module.get_debug_values(step)) return_log.update(self.netsG.module.get_debug_values(step))
if hasattr(self.netD.module, "get_debug_values"): if hasattr(self.netsD.module, "get_debug_values"):
return_log.update(self.netD.module.get_debug_values(step)) return_log.update(self.netsD.module.get_debug_values(step))
return return_log return return_log

View File

@ -0,0 +1,9 @@
def create_generator_loss(opt_loss):
pass
class GeneratorLoss:
def __init__(self, opt):
self.opt = opt
def get_loss(self, var_L, var_H, var_Gen, extras=None):

View File

@ -0,0 +1,46 @@
# Defines the expected API for a step
class SrGanGeneratorStep:
def __init__(self, opt_step, opt, netsG, netsD):
self.step_opt = opt_step
self.opt = opt
self.gen = netsG['base']
self.disc = netsD['base']
for loss in self.step_opt['losses']:
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
if l_pix_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_w = train_opt['pixel_weight']
else:
logger.info('Remove pixel loss.')
self.cri_pix = None
# Returns all optimizers used in this step.
def get_optimizers(self):
pass
# Returns optimizers which are opting in for default LR scheduling.
def get_optimizers_with_default_scheduler(self):
pass
# Returns the names of the networks this step will train. Other networks will be frozen.
def get_networks_trained(self):
pass
# Performs all forward and backward passes for this step given an input state. All input states are lists or
# chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step
# exports (which may be used by subsequent steps)
def do_forward_backward(self, state, grad_accum_step):
return state
# Performs the optimizer step after all gradient accumulation is completed.
def do_step(self):
pass

View File

@ -1,6 +1,6 @@
def create_step(opt_step): def create_step(opt, opt_step, netsG, netsD):
pass pass