forked from mrq/DL-Art-School
bba283776c
attention_norm has some parameters which are not used to compute grad, which is causing failures in the distributed case.
700 lines
36 KiB
Python
700 lines
36 KiB
Python
import logging
|
|
from collections import OrderedDict
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
|
import models.networks as networks
|
|
import models.lr_scheduler as lr_scheduler
|
|
from models.base_model import BaseModel
|
|
from models.loss import GANLoss
|
|
from apex import amp
|
|
import torch.nn.functional as F
|
|
import glob
|
|
import random
|
|
|
|
import torchvision.utils as utils
|
|
import os
|
|
|
|
# Module-level logger; uses the named 'base' logger rather than this module's __name__.
logger = logging.getLogger(name='base')
|
|
|
|
|
|
class SRGANModel(BaseModel):
    """SRGAN-style trainer for a generator (netG) and a discriminator (netD).

    Optional pieces, all driven by ``opt``:
      * netC - a fixed "corruptor" generator (``network_C``) that is randomly applied to the
        LQ inputs; its weights are re-drawn from ``pretrained_corruptors_dir`` every
        ``corruptor_swapout_steps`` steps.
      * netF - a perceptual feature network used by the feature loss (``feature_weight > 0``).

    Each fed batch is split into ``mega_batch_factor`` chunks; gradients from all chunks are
    accumulated before a single optimizer step. Mixed precision goes through apex ``amp``.
    """

    def __init__(self, opt):
        """Build networks, losses, optimizers and schedulers from the option dict ``opt``.

        Optional ``train`` options are read with ``[]`` and treated as unset when falsy, so
        ``opt`` is presumably a dict that returns None for missing keys (TODO confirm).
        """
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)

        if 'network_C' in opt.keys():
            self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
            # The corruptor net is fixed. Lock 'her down: eval mode and no gradients, so it is
            # never trained.
            self.netC.eval()
            for p in self.netC.parameters():
                p.requires_grad = False
        else:
            self.netC = None

        # define losses, optimizer and scheduler
        if self.is_train:
            # Number of chunks each fed batch is split into (gradient accumulation).
            self.mega_batch_factor = train_opt['mega_batch_factor']
            if self.mega_batch_factor is None:
                self.mega_batch_factor = 1
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
                self.l_fea_w_start = train_opt['feature_weight']
                # Optional linear decay of the feature weight towards feature_weight_minimum.
                self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
                self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
                self.l_fea_w_minimum = train_opt['feature_weight_minimum']
                if self.l_fea_w_decay_start:
                    self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio, D_init_iters and friends; unset (falsy) options use the default after `or`.
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0
            self.G_warmup = train_opt['G_warmup'] or 0
            self.D_noise_theta = train_opt['D_noise_theta_init'] or 0
            self.D_noise_final = train_opt['D_noise_final_it'] or 0
            self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] or 0
            self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] or 500
            self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] or .5

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] or 0
            if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
                optim_params = self.netG.get_param_groups()
            else:
                optim_params = self._collect_trainable_params(self.netG)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            optim_params = self._collect_trainable_params(self.netD)
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # AMP. Three losses are scaled separately: loss_id 0 = G total, 1 = D fake, 2 = D real.
            [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
                amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)

        # DataParallel (wrapping happens after amp.initialize above).
        # find_unused_parameters=True: some parameters do not contribute to the loss on every step.
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()],
                                                find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            if opt['dist']:
                self.netD = DistributedDataParallel(self.netD,
                                                    device_ids=[torch.cuda.current_device()],
                                                    find_unused_parameters=True)
            else:
                self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state'],
                                                         force_lr=train_opt['force_lr']))
            elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
                # Only supported when there are two optimizers: G and D.
                assert len(self.optimizers) == 2
                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
                                                                           self.netG.module.get_progressive_starts(),
                                                                           train_opt['lr_gamma']))
                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
                                                                           [0],
                                                                           train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

            # Swapout params
            self.swapout_G_freq = train_opt['swapout_G_freq'] or 0
            self.swapout_G_duration = 0
            self.swapout_D_freq = train_opt['swapout_D_freq'] or 0
            self.swapout_D_duration = 0
            self.swapout_duration = train_opt['swapout_duration'] or 0

        self.print_network()  # print network
        self.load()  # load G and D if needed
        self.load_random_corruptor()

        # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
        self.updated = True

    def _collect_trainable_params(self, net):
        """Return the parameters of ``net`` that require grad; warn (rank <= 0 only) about the others."""
        optim_params = []
        for k, v in net.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            elif self.rank <= 0:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        return optim_params

    def feed_data(self, data, need_GT=True):
        """Move a batch to the device, optionally corrupt its LQ images and split it into chunks.

        Args:
            data: mapping with 'LQ' and, when ``need_GT`` is true, 'GT' and 'PIX'; an optional
                'ref' entry replaces 'GT' as the discriminator reference.
            need_GT: whether to load the ground-truth related entries.

        Sets ``fed_LQ`` (unchunked) and the chunked lists ``var_L``, ``var_H``, ``var_ref``, ``pix``.
        """
        # Corrupt the data with the given corruptor, if specified. netC's output is indexed with
        # [0]; presumably its first output is the corrupted image (TODO confirm).
        self.fed_LQ = data['LQ'].to(self.device)
        if self.netC is not None and random.random() < self.corruptor_usage_prob:
            with torch.no_grad():
                corrupted_L = self.netC(self.fed_LQ)[0].detach()
        else:
            corrupted_L = self.fed_LQ

        self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
        if need_GT:
            self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
            self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]

        if not self.updated:
            self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
            self.updated = True

    def optimize_parameters(self, step):
        """Run one training iteration: a generator update followed by a discriminator update.

        G is updated only when ``step % D_update_ratio == 0`` and ``step >= D_init_iters``;
        D only when ``l_gan_w > 0`` and ``step > G_warmup``. Gradients are accumulated over the
        chunks prepared by ``feed_data`` and applied with one optimizer step per network.
        Every 50 steps sample images are written under ``<models>/../temp``; metrics are
        recorded with ``add_log_entry``.
        """
        _profile = False
        if _profile:
            from time import time
            _t = time()

        train_G = step % self.D_update_ratio == 0 and step >= self.D_init_iters
        train_D = self.l_gan_w > 0 and step > self.G_warmup

        # Some generators have variants depending on the current step.
        if hasattr(self.netG.module, "update_for_step"):
            self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
        if hasattr(self.netD.module, "update_for_step"):
            self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        if step >= self.D_init_iters:
            self.optimizer_G.zero_grad()

        self.swapout_D(step)
        self.swapout_G(step)

        # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
        if train_G:
            for p in self.netG.parameters():
                if p.dtype != torch.int64 and p.dtype != torch.bool:
                    p.requires_grad = True
        else:
            for p in self.netG.parameters():
                p.requires_grad = False

        # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta.
        # It falls linearly from D_noise_theta at step 0 to D_noise_theta_floor at step D_noise_final.
        if self.D_noise_final == 0:
            noise_theta = 0
        else:
            noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor

        if _profile:
            print("Misc setup %f" % (time() - _t,))
            _t = time()

        self.fake_GenOut = []
        self.fea_GenOut = []
        self.fake_H = []
        var_ref_skips = []
        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
            fea_GenOut, fake_GenOut = self.netG(var_L)

            if _profile:
                print("Gen forward %f" % (time() - _t,))
                _t = time()

            self.fake_GenOut.append(fake_GenOut.detach())
            self.fea_GenOut.append(fea_GenOut.detach())

            l_g_total = 0
            if train_G:
                if self.cri_pix:  # pixel loss
                    l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
                    l_g_pix_log = l_g_pix / self.l_pix_w
                    l_g_total += l_g_pix
                if self.cri_fea:  # feature loss
                    real_fea = self.netF(pix).detach()
                    fake_fea = self.netF(fea_GenOut)
                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                    l_g_fea_log = l_g_fea / self.l_fea_w
                    l_g_total += l_g_fea

                    if _profile:
                        print("Fea forward %f" % (time() - _t,))
                        _t = time()

                    # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                    # in the resultant image.
                    if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w_start - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))

                # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
                # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
                # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
                # it should target this value.

                if self.l_gan_w > 0:
                    if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                        pred_g_fake = self.netD(fake_GenOut)
                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                    elif self.opt['train']['gan_type'] == 'ragan':
                        pred_d_real = self.netD(var_ref).detach()
                        pred_g_fake = self.netD(fake_GenOut)
                        l_g_gan = self.l_gan_w * (
                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
                    l_g_gan_log = l_g_gan / self.l_gan_w
                    l_g_total += l_g_gan

                # Scale the loss down by the batch factor.
                l_g_total_log = l_g_total
                l_g_total = l_g_total / self.mega_batch_factor

                with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                    l_g_total_scaled.backward()

            if _profile:
                print("Gen backward %f" % (time() - _t,))
                _t = time()

        # Only step G on iterations where it accumulated gradients. Depending on the PyTorch
        # version, zero_grad() may leave zero-valued gradients rather than None, and an Adam step
        # on zero gradients still moves the weights through momentum and weight decay.
        if train_G:
            self.optimizer_G.step()

        if _profile:
            print("Gen step %f" % (time() - _t,))
            _t = time()

        # D
        if train_D:
            for p in self.netD.parameters():
                if p.dtype != torch.int64 and p.dtype != torch.bool:
                    p.requires_grad = True

            self.optimizer_D.zero_grad()
            real_disc_images = []
            fake_disc_images = []
            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
                # Re-compute generator outputs (post-update).
                with torch.no_grad():
                    _, fake_H = self.netG(var_L)
                    # The following line detaches all generator outputs that are not None.
                    fake_H = fake_H.detach()

                if _profile:
                    print("Gen forward for disc %f" % (time() - _t,))
                    _t = time()

                # Apply noise to the inputs to slow discriminator convergence. The noise is drawn per
                # chunk so it always matches this chunk's shape (torch.chunk may make the last chunk smaller).
                noise = torch.randn_like(var_ref) * noise_theta
                var_ref = var_ref + noise
                fake_H = fake_H + noise
                l_d_fea_real = 0
                l_d_fea_fake = 0
                if self.opt['train']['gan_type'] == 'pixgan_fea':
                    # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
                    disc_fea_scale = .5
                    _, fea_real = self.netD(var_ref, output_feature_vector=True)
                    actual_fea = self.netF(var_ref)
                    l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor
                    _, fea_fake = self.netD(fake_H, output_feature_vector=True)
                    actual_fea = self.netF(fake_H)
                    l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor
                if self.opt['train']['gan_type'] == 'gan':
                    # need to forward and backward separately, since batch norm statistics differ
                    # real
                    pred_d_real = self.netD(var_ref)
                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()
                if 'pixgan' in self.opt['train']['gan_type']:
                    # randomly determine portions of the image to swap to keep the discriminator honest.
                    pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
                    disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
                    b, _, w, h = var_ref.shape
                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
                    SWAP_MAX_DIM = w // 4
                    SWAP_MIN_DIM = 16
                    assert SWAP_MAX_DIM > 0
                    if random.random() > .5:   # Make this only happen half the time. Earlier experiments had it happen
                                               # more often and the model was "cheating" by using the presence of
                                               # easily discriminated fake swaps to count the entire generated image
                                               # as fake.
                        random_swap_count = random.randint(0, 4)
                        for i in range(random_swap_count):
                            # Make the swap across fake_H and var_ref
                            swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
                            swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
                            if swap_x + swap_w > w:
                                swap_w = w - swap_x
                            if swap_y + swap_h > h:
                                swap_h = h - swap_y
                            t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                            fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
                            var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                            real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                            fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

                    # Interpolate down to the dimensionality that the discriminator uses.
                    real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
                    fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")

                    # We're also assuming that this is exactly how the flattened discriminator output is generated.
                    real = real.view(-1, 1)
                    fake = fake.view(-1, 1)

                    # real
                    pred_d_real = self.netD(var_ref)
                    l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    l_d_real += l_d_fea_real
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    l_d_fake += l_d_fea_fake
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    # Min-shift and max-normalize the per-pixel predictions into viewable images.
                    pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
                    pdr = pdr / torch.max(pdr)
                    real_disc_images.append(pdr.view(disc_output_shape))
                    pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
                    pdf = pdf / torch.max(pdf)
                    fake_disc_images.append(pdf.view(disc_output_shape))

                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_fake = self.netD(fake_H).detach()
                    pred_d_real = self.netD(var_ref)

                    if _profile:
                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()

                    if _profile:
                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    if _profile:
                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
                var_ref_skips.append(var_ref.detach())
                self.fake_H.append(fake_H.detach())
            self.optimizer_D.step()

            if _profile:
                print("Disc step %f" % (time() - _t,))
                _t = time()

        # Periodically save sample images from every microbatch.
        if step % 50 == 0:
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)

            # torch.chunk may return fewer than mega_batch_factor chunks, so iterate over the
            # chunks that were actually processed.
            for i in range(len(self.fake_GenOut)):
                utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
                if train_D and 'pixgan' in self.opt['train']['gan_type']:
                    utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
                    utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
                    utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
                    utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))

        # Log metrics (values from the last processed chunk). Everything is converted to a Python
        # number so the rotating log buffers do not hold device tensors or autograd graph references.
        if train_G:
            if self.cri_pix:
                self.add_log_entry('l_g_pix', l_g_pix_log.item())
            if self.cri_fea:
                self.add_log_entry('feature_weight', self.l_fea_w)
                self.add_log_entry('l_g_fea', l_g_fea_log.item())
            if self.l_gan_w > 0:
                self.add_log_entry('l_g_gan', l_g_gan_log.item())
            self.add_log_entry('l_g_total', l_g_total_log.item())
        if train_D:
            self.add_log_entry('l_d_real', l_d_real_log.item())
            self.add_log_entry('l_d_fake', l_d_fake_log.item())
            self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()).item())
            self.add_log_entry('D_diff', (torch.mean(pred_d_fake) - torch.mean(pred_d_real)).item())
            if self.opt['train']['gan_type'] == 'pixgan_fea':
                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)

        # Log learning rates.
        for i, pg in enumerate(self.optimizer_G.param_groups):
            self.add_log_entry('gen_lr_%i' % (i,), pg['lr'])
        for i, pg in enumerate(self.optimizer_D.param_groups):
            self.add_log_entry('disc_lr_%i' % (i,), pg['lr'])

        if step % self.corruptor_swapout_steps == 0 and step > 0:
            self.load_random_corruptor()

    # Allows the log to serve as an easy-to-use rotating buffer.
    def add_log_entry(self, key, value):
        """Record ``value`` in a rotating buffer holding the 50 most recent values for ``key``.

        The write position is kept in ``log_dict["<key>_it"]`` (an int, so ``get_current_log``
        skips it).
        """
        key_it = "%s_it" % (key,)
        log_rotating_buffer_size = 50
        if key not in self.log_dict.keys():
            self.log_dict[key] = []
            self.log_dict[key_it] = 0
        if len(self.log_dict[key]) < log_rotating_buffer_size:
            self.log_dict[key].append(value)
        else:
            self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
        self.log_dict[key_it] += 1

    def pick_rand_prev_model(self, model_suffix):
        """Return a random ``*_<model_suffix>.pth`` path from the models dir, or None if there are fewer than two."""
        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
        if len(previous_models) <= 1:
            return None
        # Just a note: this intentionally includes the swap model in the list of possibilities.
        return random.choice(previous_models)

    def compute_fea_loss(self, real, fake):
        """Return the feature loss between one real and one fake image as a Python float.

        A leading batch dimension is added to each input before it goes through netF.
        """
        with torch.no_grad():
            real = real.unsqueeze(dim=0)
            fake = fake.unsqueeze(dim=0)
            real_fea = self.netF(real).detach()
            fake_fea = self.netF(fake)
            return self.cri_fea(fake_fea, real_fea).item()

    # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
    def force_restore_swapout(self):
        """Immediately reload the stashed current D and/or G weights if a swapout is in progress."""
        if self.swapout_D_duration > 0:
            logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
            self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
            self.stashed_D = None
            self.swapout_D_duration = 0
        if self.swapout_G_duration > 0:
            logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
            self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
            self.stashed_G = None
            self.swapout_G_duration = 0

    def swapout_D(self, step):
        """Advance D's swapout state for ``step``.

        While a swap is active, count it down and reload the stashed current weights when it
        expires. Otherwise, every ``swapout_D_freq`` steps, stash the current D (via
        save_network, whose return value is later passed to load_network) and load a random
        previous D checkpoint for ``swapout_duration`` steps. Swapping is skipped when
        ``swapout_duration`` is not positive, since such a swap could never be undone.
        """
        if self.swapout_D_duration > 0:
            self.swapout_D_duration -= 1
            if self.swapout_D_duration == 0:
                # Swap back.
                logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
                self.stashed_D = None
        elif self.swapout_D_freq != 0 and self.swapout_duration > 0 and step % self.swapout_D_freq == 0:
            swapped_model = self.pick_rand_prev_model('D')
            if swapped_model is not None:
                logger.info("Swapping to previous D model: %s" % (swapped_model,))
                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
                self.swapout_D_duration = self.swapout_duration

    def swapout_G(self, step):
        """Advance G's swapout state for ``step``.

        While a swap is active, count it down and reload the stashed current weights when it
        expires. Otherwise, every ``swapout_G_freq`` steps, stash the current G (via
        save_network, whose return value is later passed to load_network) and load a random
        previous G checkpoint for ``swapout_duration`` steps. Swapping is skipped when
        ``swapout_duration`` is not positive, since such a swap could never be undone.
        """
        if self.swapout_G_duration > 0:
            self.swapout_G_duration -= 1
            if self.swapout_G_duration == 0:
                # Swap back.
                logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
                self.stashed_G = None
        elif self.swapout_G_freq != 0 and self.swapout_duration > 0 and step % self.swapout_G_freq == 0:
            swapped_model = self.pick_rand_prev_model('G')
            if swapped_model is not None:
                logger.info("Swapping to previous G model: %s" % (swapped_model,))
                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
                self.swapout_G_duration = self.swapout_duration

    def test(self):
        """Run netG in eval mode on the first LQ chunk and store its raw output in ``fake_GenOut``."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_GenOut = [self.netG(self.var_L[0])]
        self.netG.train()

    # Fetches a summary of the log.
    def get_current_log(self, step):
        """Return the mean of every rotating log buffer, plus any ``get_debug_values(step)`` from G/D."""
        return_log = {}
        for k in self.log_dict.keys():
            if not isinstance(self.log_dict[k], list):
                continue
            return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])

        # Some generators can do their own metric logging.
        if hasattr(self.netG.module, "get_debug_values"):
            return_log.update(self.netG.module.get_debug_values(step))
        if hasattr(self.netD.module, "get_debug_values"):
            return_log.update(self.netD.module.get_debug_values(step))

        return return_log

    def get_current_visuals(self, need_GT=True):
        """Return CPU float copies of the first chunk's LQ input ('LQ'), G output ('rlt') and, optionally, 'GT'.

        If the stored generator output is a tuple, its first element is used as 'rlt'.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L[0].detach().float().cpu()
        gen_batch = self.fake_GenOut[0]
        if isinstance(gen_batch, tuple):
            gen_batch = gen_batch[0]
        out_dict['rlt'] = gen_batch.detach().float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H[0].detach().float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter count of G and, when training, of D and (if used) F.

        Only processes with ``rank <= 0`` (rank 0, or -1 when not distributed) emit log lines.
        """
        def _log_network(net, template):
            s, n = self.get_network_description(net)
            if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
                net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                                 net.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info(template.format(net_struc_str, n))
                logger.info(s)

        # Generator
        _log_network(self.netG, 'Network G structure: {}, with parameters: {:,d}')
        if self.is_train:
            # Discriminator
            _log_network(self.netD, 'Network D structure: {}, with parameters: {:,d}')
            if self.cri_fea:  # F, Perceptual Network
                _log_network(self.netF, 'Network F structure: {}, with parameters: {:,d}')

    def load(self):
        """Load pretrained G weights, and D weights when training, from the paths in ``opt['path']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

    def load_random_corruptor(self):
        """Load a randomly chosen ``*.pth`` from ``pretrained_corruptors_dir`` into netC (no-op without netC).

        Raises:
            FileNotFoundError: if the corruptor directory contains no ``*.pth`` files.
        """
        if self.netC is None:
            return
        corruptor_dir = self.opt['path']['pretrained_corruptors_dir']
        corruptor_files = glob.glob(os.path.join(corruptor_dir, "*.pth"))
        if not corruptor_files:
            raise FileNotFoundError('No corruptor checkpoints (*.pth) found in %s' % (corruptor_dir,))
        corruptor_to_load = random.choice(corruptor_files)
        logger.info('Swapping corruptor to: %s' % (corruptor_to_load,))
        self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save G and D checkpoints tagged with ``iter_step``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)