From ab04ca1778cbc8e16a653d2e0f18d4d33f9d7041 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 12 Aug 2020 08:45:23 -0600 Subject: [PATCH] Extensible trainer (in progress) --- codes/models/ExtensibleTrainer.py | 661 ++++++++++++++++++++++++++++++ codes/models/lr_scheduler.py | 24 ++ codes/models/steps/steps.py | 29 ++ 3 files changed, 714 insertions(+) create mode 100644 codes/models/ExtensibleTrainer.py create mode 100644 codes/models/steps/steps.py diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py new file mode 100644 index 00000000..1ce477ac --- /dev/null +++ b/codes/models/ExtensibleTrainer.py @@ -0,0 +1,661 @@ +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +from models.steps.steps import create_step +import models.lr_scheduler as lr_scheduler +from models.base_model import BaseModel +from models.loss import GANLoss, FDPLLoss +from apex import amp +from data.weight_scheduler import get_scheduler_for_opt +from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding +import torch.nn.functional as F +import glob +import random + +import torchvision.utils as utils +import os + +logger = logging.getLogger('base') + + +class ExtensibleTrainer(BaseModel): + def __init__(self, opt): + super(ExtensibleTrainer, self).__init__(opt) + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + self.mega_batch_factor = 1 + + self.netG = {} + self.netD = {} + self.networks = [] + for name, net in opt['networks'].items(): + if net['type'] == 'generator': + new_net = networks.define_G(net) + self.netG[name] = new_net + elif net['type'] == 'discriminator': + new_net = networks.define_D(net) + self.netD[name] = new_net + else: + raise NotImplementedError("Can only handle generators and discriminators") + self.networks.append(new_net) + + if self.is_train: + self.mega_batch_factor = train_opt['mega_batch_factor'] + if self.mega_batch_factor is None: + self.mega_batch_factor = 1 + + # Initialize amp. + amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_level'], num_losses=len(self.optimizers)) + # self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use + # self.netG and self.netD for that. + self.networks = amp_nets + + # DataParallel + dnets = [] + for anet in amp_nets: + if opt['dist']: + dnet = DistributedDataParallel(anet, + device_ids=[torch.cuda.current_device()], + find_unused_parameters=True) + else: + dnet = DataParallel(anet) + if self.is_train: + dnet.train() + else: + dnet.eval() + dnets.append(dnet) + + # Backpush the wrapped networks into the network dicts.. + found = 0 + for dnet in dnets: + for net_dict in [self.netD, self.netG]: + for k, v in net_dict.items(): + if v == dnet: + net_dict[k] = dnet + found += 1 + assert found == len(self.networks) + + # Initialize the training steps + self.steps = [] + for step in opt['steps']: + step = create_step(step, self.netG, self.netD) + self.steps.append(step) + self.optimizers.extend(step.get_optimizers()) + + # Find the optimizers that are using the default scheduler, then build them. + def_opt = [] + for s in self.steps: + def_opt.extend(s.get_optimizers_with_default_scheduler()) + lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt) + + self.print_network() # print network + self.load() # load G and D if needed + + # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. + self.updated = True + + def feed_data(self, data): + self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0) + self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] + input_ref = data['ref'] if 'ref' in data else data['GT'] + self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + + def optimize_parameters(self, step): + # Some models need to make parametric adjustments per-step. Do that here. + for net in self.networks.values(): + if hasattr(net, "update_for_step"): + net.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + + # Iterate through the steps, performing them one at a time. + state = {'lr': self.var_L, 'hr': self.var_H, 'ref': self.var_ref} + for s in self.steps: + # Only set requires_grad=True for the network being trained. + nets_to_train = s.get_networks_trained() + for name, net in self.networks.items(): + net_enabled = name in nets_to_train + for p in self.netG.parameters(): + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = net_enabled + else: + p.requires_grad = False + + # Now do a forward and backward pass for each gradient accumulation step. + for m in range(self.mega_batch_factor): + state = s.do_forward_backward(state, m) + + # And finally perform optimization. + s.do_step() + + + + # G + for p in self.netD.parameters(): + p.requires_grad = False + if self.spsr_enabled: + for p in self.netD_grad.parameters(): + p.requires_grad = False + + self.swapout_D(step) + self.swapout_G(step) + + # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: + if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters: + for k, v in self.netG.named_parameters(): + if v.dtype != torch.int64 and v.dtype != torch.bool: + v.requires_grad = '_branch_pretrain' in k + else: + for p in self.netG.parameters(): + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = True + else: + for p in self.netG.parameters(): + p.requires_grad = False + + # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta. + if self.D_noise_final == 0: + noise_theta = 0 + else: + noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor + + if _profile: + print("Misc setup %f" % (time() - _t,)) + _t = time() + + if step >= self.D_init_iters: + self.optimizer_G.zero_grad() + self.fake_GenOut = [] + self.fea_GenOut = [] + self.fake_H = [] + self.spsr_grad_GenOut = [] + var_ref_skips = [] + for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): + if self.spsr_enabled: + using_gan_img = False + # SPSR models have outputs from three different branches. + fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L) + fea_GenOut = fake_GenOut + self.spsr_grad_GenOut.append(fake_H_branch) + # Get image gradients for later use. + fake_H_grad = self.get_grad_nopadding(fake_GenOut) + else: + if random.random() > self.gan_lq_img_use_prob: + fea_GenOut, fake_GenOut = self.netG(var_L) + using_gan_img = False + else: + fea_GenOut, fake_GenOut = self.netG(var_LGAN) + using_gan_img = True + + if _profile: + print("Gen forward %f" % (time() - _t,)) + _t = time() + + self.fake_GenOut.append(fake_GenOut.detach()) + self.fea_GenOut.append(fea_GenOut.detach()) + + l_g_total = 0 + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: + fea_w = self.l_fea_sched.get_weight_for_step(step) + l_g_pix_log = None + l_g_fea_log = None + l_g_fdpl = None + l_g_fea_log = None + if self.cri_pix and not using_gan_img: # pixel loss + l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) + l_g_pix_log = l_g_pix / self.l_pix_w + l_g_total += l_g_pix + if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss + if self.disjoint_data: + grad_truth = self.get_grad_nopadding(var_L) + grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest") + else: + grad_truth = self.get_grad_nopadding(var_H) + grad_pred = fake_H_grad + l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth) + l_g_total += l_g_pix_grad + if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss + if self.disjoint_data: + grad_truth = self.get_grad_nopadding(var_L) + grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest") + else: + grad_truth = self.get_grad_nopadding(var_H) + grad_pred = fake_H_branch + l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth) + l_g_total += l_g_pix_grad_branch + if self.fdpl_enabled and not using_gan_img: + l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) + l_g_total += l_g_fdpl * self.fdpl_weight + if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss + if self.lr_netF is not None: + real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale']) + else: + real_fea = self.netF(pix).detach() + fake_fea = self.netF(fea_GenOut) + l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea) + l_g_fea_log = l_g_fea / fea_w + l_g_total += l_g_fea + + if _profile: + print("Fea forward %f" % (time() - _t,)) + _t = time() + + # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931 + # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is + # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, + # it should target this + + l_g_fix_disc = torch.zeros(1, requires_grad=False, device=self.device).squeeze() + for fixed_disc in self.fixed_disc_nets: + weight = fixed_disc.module.fdisc_weight + real_fea = fixed_disc(pix).detach() + fake_fea = fixed_disc(fea_GenOut) + l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea) + l_g_total += l_g_fix_disc + + + if self.l_gan_w > 0: + if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: + if self.opt['train']['gan_type'] == 'crossgan': + pred_g_fake = self.netD(fake_GenOut, var_L) + else: + pred_g_fake = self.netD(fake_GenOut) + l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_real = self.netD(var_ref).detach() + pred_g_fake = self.netD(fake_GenOut) + l_g_gan = self.l_gan_w * ( + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_gan_log = l_g_gan / self.l_gan_w + l_g_total += l_g_gan + + if self.spsr_enabled and self.cri_grad_gan: + if self.opt['train']['gan_type'] == 'crossgan': + pred_g_fake_grad = self.netD(fake_H_grad, var_L) + else: + pred_g_fake_grad = self.netD(fake_H_grad) + pred_g_fake_grad_branch = self.netD_grad(fake_H_branch) + if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: + l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) + l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() + l_g_gan_grad = self.l_gan_w * ( + self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + + self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 + l_g_gan_grad_branch = self.l_gan_w * ( + self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) + + self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2 + l_g_total += l_g_gan_grad + l_g_gan_grad_branch + + # Scale the loss down by the batch factor. + l_g_total_log = l_g_total + l_g_total = l_g_total / self.mega_batch_factor + + with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: + l_g_total_scaled.backward() + + if _profile: + print("Gen backward %f" % (time() - _t,)) + _t = time() + + self.optimizer_G.step() + + if _profile: + print("Gen step %f" % (time() - _t,)) + _t = time() + + # D + if self.l_gan_w > 0 and step >= self.G_warmup: + for p in self.netD.parameters(): + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = True + + noise = torch.randn_like(var_ref) * noise_theta + noise.to(self.device) + real_disc_images = [] + fake_disc_images = [] + for fake_GenOut, var_LGAN, var_L, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_L, self.var_H, self.var_ref, self.pix): + if random.random() > self.gan_lq_img_use_prob: + fake_H = fake_GenOut.clone().detach().requires_grad_(False) + else: + # Re-compute generator outputs with the GAN inputs. + with torch.no_grad(): + if self.spsr_enabled: + _, fake_H, _ = self.netG(var_LGAN) + else: + _, fake_H = self.netG(var_LGAN) + fake_H = fake_H.detach() + + if _profile: + print("Gen forward for disc %f" % (time() - _t,)) + _t = time() + + # Apply noise to the inputs to slow discriminator convergence. + var_ref = var_ref + noise + fake_H = fake_H + noise + l_d_fea_real = 0 + l_d_fea_fake = 0 + self.optimizer_D.zero_grad() + if self.opt['train']['gan_type'] == 'pixgan_fea': + # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. + disc_fea_scale = .1 + _, fea_real = self.netD(var_ref, output_feature_vector=True) + actual_fea = self.netF(var_ref) + l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor + _, fea_fake = self.netD(fake_H, output_feature_vector=True) + actual_fea = self.netF(fake_H) + l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor + if self.opt['train']['gan_type'] == 'crossgan': + # need to forward and backward separately, since batch norm statistics differ + # real + pred_d_real = self.netD(var_ref, var_L) + l_d_real = self.cri_gan(pred_d_real, True) + l_d_real_log = l_d_real + # fake + pred_d_fake = self.netD(fake_H, var_L) + l_d_fake = self.cri_gan(pred_d_fake, False) + l_d_fake_log = l_d_fake + # mismatched + mismatched_L = torch.roll(var_L, shifts=1, dims=0) + pred_d_real_mismatched = self.netD(var_ref, mismatched_L) + pred_d_fake_mismatched = self.netD(fake_H, mismatched_L) + l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2 + + l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3 + l_d_total = l_d_total / self.mega_batch_factor + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() + elif self.opt['train']['gan_type'] == 'gan': + # real + pred_d_real = self.netD(var_ref) + l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor + l_d_real_log = l_d_real * self.mega_batch_factor + # fake + pred_d_fake = self.netD(fake_H) + l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor + l_d_fake_log = l_d_fake * self.mega_batch_factor + + l_d_total = (l_d_real + l_d_fake) / 2 + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() + elif 'pixgan' in self.opt['train']['gan_type']: + pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() + disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) + b, _, w, h = var_ref.shape + real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) + fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) + if not self.disjoint_data: + # randomly determine portions of the image to swap to keep the discriminator honest. + SWAP_MAX_DIM = w // 4 + SWAP_MIN_DIM = 16 + assert SWAP_MAX_DIM > 0 + if random.random() > .5: # Make this only happen half the time. Earlier experiments had it happen + # more often and the model was "cheating" by using the presence of + # easily discriminated fake swaps to count the entire generated image + # as fake. + random_swap_count = random.randint(0, 4) + for i in range(random_swap_count): + # Make the swap across fake_H and var_ref + swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM) + swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM) + if swap_x + swap_w > w: + swap_w = w - swap_x + if swap_y + swap_h > h: + swap_h = h - swap_y + t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() + fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] + var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t + real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 + fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 + + # Interpolate down to the dimensionality that the discriminator uses. + real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False) + fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False) + + # We're also assuming that this is exactly how the flattened discriminator output is generated. + real = real.view(-1, 1) + fake = fake.view(-1, 1) + + # real + pred_d_real = self.netD(var_ref) + l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor + l_d_real_log = l_d_real * self.mega_batch_factor + l_d_real += l_d_fea_real + # fake + pred_d_fake = self.netD(fake_H) + l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor + l_d_fake_log = l_d_fake * self.mega_batch_factor + l_d_fake += l_d_fea_fake + + l_d_total = (l_d_real + l_d_fake) / 2 + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() + + pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real)) + pdr = pdr / torch.max(pdr) + real_disc_images.append(pdr.view(disc_output_shape)) + pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake)) + pdf = pdf / torch.max(pdf) + fake_disc_images.append(pdf.view(disc_output_shape)) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_fake = self.netD(fake_H) + pred_d_real = self.netD(var_ref) + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_real_log = l_d_real + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_fake_log = l_d_fake + l_d_total = (l_d_real + l_d_fake) / 2 + l_d_total /= self.mega_batch_factor + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() + var_ref_skips.append(var_ref.detach()) + self.fake_H.append(fake_H.detach()) + self.optimizer_D.step() + + if _profile: + print("Disc step %f" % (time() - _t,)) + _t = time() + + # D_grad. + if self.spsr_enabled and self.cri_grad_gan and step >= self.G_warmup: + for p in self.netD_grad.parameters(): + p.requires_grad = True + self.optimizer_D_grad.zero_grad() + for var_ref, fake_H, fake_H_grad_branch in zip(var_ref_skips, self.fake_H, self.spsr_grad_GenOut): + fake_H_grad = self.get_grad_nopadding(fake_H).detach() + var_ref_grad = self.get_grad_nopadding(var_ref) + pred_d_real_grad = self.netD_grad(var_ref_grad) + pred_d_fake_grad = self.netD_grad(fake_H_grad) # Tensor already detached above. + # var_ref and fake_H already has noise added to it. We **must** add noise to fake_H_grad_branch too. + fake_H_grad_branch = fake_H_grad_branch.detach() + noise + pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch) + if self.opt['train']['gan_type'] == 'gan': + l_d_real_grad = self.cri_gan(pred_d_real_grad, True) + l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2 + elif self.opt['train']['gan_type'] == 'crossgan': + assert False + elif self.opt['train']['gan_type'] == 'pixgan': + real = torch.ones_like(pred_d_real_grad) + fake = torch.zeros_like(pred_d_fake_grad) + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real) + l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad, fake) + \ + self.cri_grad_gan(pred_d_fake_grad_branch, fake)) / 2 + elif self.opt['train']['gan_type'] == 'ragan': + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) + l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + \ + self.cri_grad_gan(pred_d_fake_grad_branch - torch.mean(pred_d_real_grad), False)) / 2 + + l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 + l_d_total_grad /= self.mega_batch_factor + with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: + l_d_total_grad_scaled.backward() + self.optimizer_D_grad.step() + + + # Log sample images from first microbatch. + if step % self.img_debug_steps == 0: + sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") + os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) + if self.spsr_enabled: + os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True) + + for i in range(self.mega_batch_factor): + utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) + if self.spsr_enabled: + utils.save_image(self.spsr_grad_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i_%02i.png" % (step, i))) + if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: + utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i))) + + # Log metrics + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: + if self.cri_pix and l_g_pix_log is not None: + self.add_log_entry('l_g_pix', l_g_pix_log.detach().item()) + if self.fdpl_enabled and l_g_fdpl is not None: + self.add_log_entry('l_g_fdpl', l_g_fdpl.detach().item()) + if self.cri_fea and l_g_fea_log is not None: + self.add_log_entry('feature_weight', fea_w) + self.add_log_entry('l_g_fea', l_g_fea_log.detach().item()) + self.add_log_entry('l_g_fix_disc', l_g_fix_disc.detach().item()) + if self.l_gan_w > 0: + self.add_log_entry('l_g_gan', l_g_gan_log.detach().item()) + self.add_log_entry('l_g_total', l_g_total_log.detach().item()) + if self.opt['train']['gan_type'] == 'pixgan_fea': + self.add_log_entry('l_d_fea_fake', l_d_fea_fake.detach().item() * self.mega_batch_factor) + self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor) + self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor) + self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor) + if self.opt['train']['gan_type'] == 'crossgan': + self.add_log_entry('l_d_mismatched', l_d_mismatched.detach().item()) + if self.spsr_enabled: + if self.cri_pix_grad: + self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item()) + if self.cri_pix_branch: + self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item()) + if self.cri_grad_gan: + self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item()) + self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item()) + if self.l_gan_w > 0 and step >= self.G_warmup: + self.add_log_entry('l_d_real', l_d_real_log.detach().item()) + self.add_log_entry('l_d_fake', l_d_fake_log.detach().item()) + self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + self.add_log_entry('D_diff', torch.mean(pred_d_fake.detach()) - torch.mean(pred_d_real.detach())) + if self.spsr_enabled: + self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item()) + self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item()) + self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach())) + self.add_log_entry('D_diff_grad', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach())) + + # Log learning rates. + for i, pg in enumerate(self.optimizer_G.param_groups): + self.add_log_entry('gen_lr_%i' % (i,), pg['lr']) + for i, pg in enumerate(self.optimizer_D.param_groups): + self.add_log_entry('disc_lr_%i' % (i,), pg['lr']) + + if step % self.corruptor_swapout_steps == 0 and step > 0: + self.load_random_corruptor() + + # Allows the log to serve as an easy-to-use rotating buffer. + def add_log_entry(self, key, value): + key_it = "%s_it" % (key,) + log_rotating_buffer_size = 50 + if key not in self.log_dict.keys(): + self.log_dict[key] = [] + self.log_dict[key_it] = 0 + if len(self.log_dict[key]) < log_rotating_buffer_size: + self.log_dict[key].append(value) + else: + self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value + self.log_dict[key_it] += 1 + + def compute_fea_loss(self, real, fake): + with torch.no_grad(): + real = real.unsqueeze(dim=0).to(self.device) + fake = fake.unsqueeze(dim=0).to(self.device) + real_fea = self.netF(real).detach() + fake_fea = self.netF(fake) + return self.cri_fea(fake_fea, real_fea).item() + + def test(self): + self.netG.eval() + with torch.no_grad(): + if self.spsr_enabled: + self.fake_H_branch = [] + self.fake_GenOut = [] + self.grad_LR = [] + fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0]) + self.fake_H_branch.append(fake_H_branch) + self.fake_GenOut.append(fake_GenOut) + self.grad_LR.append(grad_LR) + else: + self.fake_GenOut = [self.netG(self.var_L[0])] + self.netG.train() + + # Fetches a summary of the log. + def get_current_log(self, step): + return_log = {} + for k in self.log_dict.keys(): + if not isinstance(self.log_dict[k], list): + continue + return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) + + # Some generators can do their own metric logging. + if hasattr(self.netG.module, "get_debug_values"): + return_log.update(self.netG.module.get_debug_values(step)) + if hasattr(self.netD.module, "get_debug_values"): + return_log.update(self.netD.module.get_debug_values(step)) + + return return_log + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L[0].detach().float().cpu() + gen_batch = self.fake_GenOut[0] + if isinstance(gen_batch, tuple): + gen_batch = gen_batch[0] + out_dict['rlt'] = gen_batch.detach().float().cpu() + if need_GT: + out_dict['GT'] = self.var_H[0].detach().float().cpu() + if self.spsr_enabled: + out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu() + out_dict['LR_grad'] = self.grad_LR[0].float().cpu() + return out_dict + + def print_network(self): + for name, net in self.networks.items(): + s, n = self.get_network_description(net) + net_struc_str = '{}'.format(net.__class__.__name__) + if self.rank <= 0: + logger.info('Network ' + name + ' structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + for name, net in self.networks.items(): + load_path = opt['path'][name] + if load_path is not None: + logger.info('Loading model for %s: [%s]' % (name, load_path)) + self.load_network(load_path, net) + + def save(self, iter_step): + for name, net in self.networks.items(): + self.save_network(net, name, iter_step) diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py index da90c223..f299b87f 100644 --- a/codes/models/lr_scheduler.py +++ b/codes/models/lr_scheduler.py @@ -5,6 +5,30 @@ import torch from torch.optim.lr_scheduler import _LRScheduler +def get_scheduler_for_name(name, optimizers, scheduler_opt): + schedulers = [] + for o in optimizers: + if name == 'MultiStepLR': + sched = MultiStepLR_Restart(o, scheduler_opt['gen_lr_steps'], + restarts=scheduler_opt['restarts'], + weights=scheduler_opt['restart_weights'], + gamma=scheduler_opt['lr_gamma'], + clear_state=scheduler_opt['clear_state'], + force_lr=scheduler_opt['force_lr']) + elif name == 'ProgressiveMultiStepLR': + sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'], + scheduler_opt['progressive_starts'], + scheduler_opt['lr_gamma']) + elif name == 'CosineAnnealingLR_Restart': + sched = CosineAnnealingLR_Restart( + o, scheduler_opt['T_period'], eta_min=scheduler_opt['eta_min'], + restarts=scheduler_opt['restarts'], weights=scheduler_opt['restart_weights']) + else: + raise NotImplementedError('Scheduler not available') + schedulers.append(sched) + return schedulers + + # This scheduler is specifically designed to modulate the learning rate of several different param groups configured # by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs. class ProgressiveMultiStepLR(_LRScheduler): diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py new file mode 100644 index 00000000..b5b74431 --- /dev/null +++ b/codes/models/steps/steps.py @@ -0,0 +1,29 @@ + + +def create_step(opt_step): + pass + + +# Defines the expected API for a step +class base_step: + # Returns all optimizers used in this step. + def get_optimizers(self): + pass + + # Returns optimizers which are opting in for default LR scheduling. + def get_optimizers_with_default_scheduler(self): + pass + + # Returns the names of the networks this step will train. Other networks will be frozen. + def get_networks_trained(self): + pass + + # Performs all forward and backward passes for this step given an input state. All input states are lists or + # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step + # exports (which may be used by subsequent steps) + def do_forward_backward(self, state, grad_accum_step): + return state + + # Performs the optimizer step after all gradient accumulation is completed. + def do_step(self): + pass \ No newline at end of file