From 74cdaa2226ac9f380da0306cc3e84000462ee08d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 18 Aug 2020 08:49:32 -0600 Subject: [PATCH] Some work on extensible trainer --- codes/models/ExtensibleTrainer.py | 88 +++++++++---------- codes/models/steps/losses/generator_losses.py | 9 ++ codes/models/steps/srgan_generator_step.py | 46 ++++++++++ codes/models/steps/steps.py | 2 +- 4 files changed, 100 insertions(+), 45 deletions(-) create mode 100644 codes/models/steps/losses/generator_losses.py create mode 100644 codes/models/steps/srgan_generator_step.py diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 1ce477ac..fef8c85a 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -31,16 +31,16 @@ class ExtensibleTrainer(BaseModel): train_opt = opt['train'] self.mega_batch_factor = 1 - self.netG = {} - self.netD = {} + self.netsG = {} + self.netsD = {} self.networks = [] for name, net in opt['networks'].items(): if net['type'] == 'generator': new_net = networks.define_G(net) - self.netG[name] = new_net + self.netsG[name] = new_net elif net['type'] == 'discriminator': new_net = networks.define_D(net) - self.netD[name] = new_net + self.netsD[name] = new_net else: raise NotImplementedError("Can only handle generators and discriminators") self.networks.append(new_net) @@ -74,7 +74,7 @@ class ExtensibleTrainer(BaseModel): # Backpush the wrapped networks into the network dicts.. found = 0 for dnet in dnets: - for net_dict in [self.netD, self.netG]: + for net_dict in [self.netsD, self.netsG]: for k, v in net_dict.items(): if v == dnet: net_dict[k] = dnet @@ -84,7 +84,7 @@ class ExtensibleTrainer(BaseModel): # Initialize the training steps self.steps = [] for step in opt['steps']: - step = create_step(step, self.netG, self.netD) + step = create_step(step, self.netsG, self.netsD) self.steps.append(step) self.optimizers.extend(step.get_optimizers()) @@ -119,7 +119,7 @@ class ExtensibleTrainer(BaseModel): nets_to_train = s.get_networks_trained() for name, net in self.networks.items(): net_enabled = name in nets_to_train - for p in self.netG.parameters(): + for p in self.netsG.parameters(): if p.dtype != torch.int64 and p.dtype != torch.bool: p.requires_grad = net_enabled else: @@ -135,7 +135,7 @@ class ExtensibleTrainer(BaseModel): # G - for p in self.netD.parameters(): + for p in self.netsD.parameters(): p.requires_grad = False if self.spsr_enabled: for p in self.netD_grad.parameters(): @@ -147,15 +147,15 @@ class ExtensibleTrainer(BaseModel): # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. if step % self.D_update_ratio == 0 and step >= self.D_init_iters: if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters: - for k, v in self.netG.named_parameters(): + for k, v in self.netsG.named_parameters(): if v.dtype != torch.int64 and v.dtype != torch.bool: v.requires_grad = '_branch_pretrain' in k else: - for p in self.netG.parameters(): + for p in self.netsG.parameters(): if p.dtype != torch.int64 and p.dtype != torch.bool: p.requires_grad = True else: - for p in self.netG.parameters(): + for p in self.netsG.parameters(): p.requires_grad = False # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta. @@ -179,17 +179,17 @@ class ExtensibleTrainer(BaseModel): if self.spsr_enabled: using_gan_img = False # SPSR models have outputs from three different branches. - fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L) + fake_H_branch, fake_GenOut, grad_LR = self.netsG(var_L) fea_GenOut = fake_GenOut self.spsr_grad_GenOut.append(fake_H_branch) # Get image gradients for later use. fake_H_grad = self.get_grad_nopadding(fake_GenOut) else: if random.random() > self.gan_lq_img_use_prob: - fea_GenOut, fake_GenOut = self.netG(var_L) + fea_GenOut, fake_GenOut = self.netsG(var_L) using_gan_img = False else: - fea_GenOut, fake_GenOut = self.netG(var_LGAN) + fea_GenOut, fake_GenOut = self.netsG(var_LGAN) using_gan_img = True if _profile: @@ -262,13 +262,13 @@ class ExtensibleTrainer(BaseModel): if self.l_gan_w > 0: if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: if self.opt['train']['gan_type'] == 'crossgan': - pred_g_fake = self.netD(fake_GenOut, var_L) + pred_g_fake = self.netsD(fake_GenOut, var_L) else: - pred_g_fake = self.netD(fake_GenOut) + pred_g_fake = self.netsD(fake_GenOut) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': - pred_d_real = self.netD(var_ref).detach() - pred_g_fake = self.netD(fake_GenOut) + pred_d_real = self.netsD(var_ref).detach() + pred_g_fake = self.netsD(fake_GenOut) l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 @@ -277,9 +277,9 @@ class ExtensibleTrainer(BaseModel): if self.spsr_enabled and self.cri_grad_gan: if self.opt['train']['gan_type'] == 'crossgan': - pred_g_fake_grad = self.netD(fake_H_grad, var_L) + pred_g_fake_grad = self.netsD(fake_H_grad, var_L) else: - pred_g_fake_grad = self.netD(fake_H_grad) + pred_g_fake_grad = self.netsD(fake_H_grad) pred_g_fake_grad_branch = self.netD_grad(fake_H_branch) if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) @@ -313,7 +313,7 @@ class ExtensibleTrainer(BaseModel): # D if self.l_gan_w > 0 and step >= self.G_warmup: - for p in self.netD.parameters(): + for p in self.netsD.parameters(): if p.dtype != torch.int64 and p.dtype != torch.bool: p.requires_grad = True @@ -328,9 +328,9 @@ class ExtensibleTrainer(BaseModel): # Re-compute generator outputs with the GAN inputs. with torch.no_grad(): if self.spsr_enabled: - _, fake_H, _ = self.netG(var_LGAN) + _, fake_H, _ = self.netsG(var_LGAN) else: - _, fake_H = self.netG(var_LGAN) + _, fake_H = self.netsG(var_LGAN) fake_H = fake_H.detach() if _profile: @@ -346,26 +346,26 @@ class ExtensibleTrainer(BaseModel): if self.opt['train']['gan_type'] == 'pixgan_fea': # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. disc_fea_scale = .1 - _, fea_real = self.netD(var_ref, output_feature_vector=True) + _, fea_real = self.netsD(var_ref, output_feature_vector=True) actual_fea = self.netF(var_ref) l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor - _, fea_fake = self.netD(fake_H, output_feature_vector=True) + _, fea_fake = self.netsD(fake_H, output_feature_vector=True) actual_fea = self.netF(fake_H) l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor if self.opt['train']['gan_type'] == 'crossgan': # need to forward and backward separately, since batch norm statistics differ # real - pred_d_real = self.netD(var_ref, var_L) + pred_d_real = self.netsD(var_ref, var_L) l_d_real = self.cri_gan(pred_d_real, True) l_d_real_log = l_d_real # fake - pred_d_fake = self.netD(fake_H, var_L) + pred_d_fake = self.netsD(fake_H, var_L) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_fake_log = l_d_fake # mismatched mismatched_L = torch.roll(var_L, shifts=1, dims=0) - pred_d_real_mismatched = self.netD(var_ref, mismatched_L) - pred_d_fake_mismatched = self.netD(fake_H, mismatched_L) + pred_d_real_mismatched = self.netsD(var_ref, mismatched_L) + pred_d_fake_mismatched = self.netsD(fake_H, mismatched_L) l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2 l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3 @@ -374,11 +374,11 @@ class ExtensibleTrainer(BaseModel): l_d_total_scaled.backward() elif self.opt['train']['gan_type'] == 'gan': # real - pred_d_real = self.netD(var_ref) + pred_d_real = self.netsD(var_ref) l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor # fake - pred_d_fake = self.netD(fake_H) + pred_d_fake = self.netsD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor @@ -386,7 +386,7 @@ class ExtensibleTrainer(BaseModel): with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: l_d_total_scaled.backward() elif 'pixgan' in self.opt['train']['gan_type']: - pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() + pixdisc_channels, pixdisc_output_reduction = self.netsD.module.pixgan_parameters() disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) b, _, w, h = var_ref.shape real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) @@ -424,12 +424,12 @@ class ExtensibleTrainer(BaseModel): fake = fake.view(-1, 1) # real - pred_d_real = self.netD(var_ref) + pred_d_real = self.netsD(var_ref) l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor l_d_real += l_d_fea_real # fake - pred_d_fake = self.netD(fake_H) + pred_d_fake = self.netsD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor l_d_fake += l_d_fea_fake @@ -445,8 +445,8 @@ class ExtensibleTrainer(BaseModel): pdf = pdf / torch.max(pdf) fake_disc_images.append(pdf.view(disc_output_shape)) elif self.opt['train']['gan_type'] == 'ragan': - pred_d_fake = self.netD(fake_H) - pred_d_real = self.netD(var_ref) + pred_d_fake = self.netsD(fake_H) + pred_d_real = self.netsD(var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_real_log = l_d_real l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) @@ -597,19 +597,19 @@ class ExtensibleTrainer(BaseModel): return self.cri_fea(fake_fea, real_fea).item() def test(self): - self.netG.eval() + self.netsG.eval() with torch.no_grad(): if self.spsr_enabled: self.fake_H_branch = [] self.fake_GenOut = [] self.grad_LR = [] - fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0]) + fake_H_branch, fake_GenOut, grad_LR = self.netsG(self.var_L[0]) self.fake_H_branch.append(fake_H_branch) self.fake_GenOut.append(fake_GenOut) self.grad_LR.append(grad_LR) else: - self.fake_GenOut = [self.netG(self.var_L[0])] - self.netG.train() + self.fake_GenOut = [self.netsG(self.var_L[0])] + self.netsG.train() # Fetches a summary of the log. def get_current_log(self, step): @@ -620,10 +620,10 @@ class ExtensibleTrainer(BaseModel): return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) # Some generators can do their own metric logging. - if hasattr(self.netG.module, "get_debug_values"): - return_log.update(self.netG.module.get_debug_values(step)) - if hasattr(self.netD.module, "get_debug_values"): - return_log.update(self.netD.module.get_debug_values(step)) + if hasattr(self.netsG.module, "get_debug_values"): + return_log.update(self.netsG.module.get_debug_values(step)) + if hasattr(self.netsD.module, "get_debug_values"): + return_log.update(self.netsD.module.get_debug_values(step)) return return_log diff --git a/codes/models/steps/losses/generator_losses.py b/codes/models/steps/losses/generator_losses.py new file mode 100644 index 00000000..5ae088f7 --- /dev/null +++ b/codes/models/steps/losses/generator_losses.py @@ -0,0 +1,9 @@ +def create_generator_loss(opt_loss): + pass + + +class GeneratorLoss: + def __init__(self, opt): + self.opt = opt + + def get_loss(self, var_L, var_H, var_Gen, extras=None): \ No newline at end of file diff --git a/codes/models/steps/srgan_generator_step.py b/codes/models/steps/srgan_generator_step.py new file mode 100644 index 00000000..4d7b58ca --- /dev/null +++ b/codes/models/steps/srgan_generator_step.py @@ -0,0 +1,46 @@ +# Defines the expected API for a step +class SrGanGeneratorStep: + + def __init__(self, opt_step, opt, netsG, netsD): + self.step_opt = opt_step + self.opt = opt + self.gen = netsG['base'] + self.disc = netsD['base'] + for loss in self.step_opt['losses']: + + # G pixel loss + if train_opt['pixel_weight'] > 0: + l_pix_type = train_opt['pixel_criterion'] + if l_pix_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_w = train_opt['pixel_weight'] + else: + logger.info('Remove pixel loss.') + self.cri_pix = None + + + # Returns all optimizers used in this step. + def get_optimizers(self): + pass + + # Returns optimizers which are opting in for default LR scheduling. + def get_optimizers_with_default_scheduler(self): + pass + + # Returns the names of the networks this step will train. Other networks will be frozen. + def get_networks_trained(self): + pass + + # Performs all forward and backward passes for this step given an input state. All input states are lists or + # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step + # exports (which may be used by subsequent steps) + def do_forward_backward(self, state, grad_accum_step): + return state + + # Performs the optimizer step after all gradient accumulation is completed. + def do_step(self): + pass \ No newline at end of file diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index b5b74431..bcd6f2a2 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -1,6 +1,6 @@ -def create_step(opt_step): +def create_step(opt, opt_step, netsG, netsD): pass