forked from mrq/DL-Art-School
0b7193392f
The latest discriminator architecture was already pretty much a unet. This one makes that official and uses shared layers. It also upsamples one additional time and throws out the lowest upsampling result. The intent is to delete the old vgg pixdisc, but I'll keep it around for a bit since I'm still trying out a few models with it.
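For orientation, here is a rough illustration of the idea the commit describes: a discriminator whose downsampling trunk is shared with an upsampling path through skip connections, producing per-pixel real/fake logits. This is a hand-written sketch only, not the discriminator defined in this repository; the class name UNetDiscriminator and all layer sizes are invented for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetDiscriminator(nn.Module):
    # Illustrative only: a tiny U-Net-style discriminator. The conv trunk that downsamples the
    # input is re-used (via skip connections) by the upsampling path, so the two halves share layers.
    def __init__(self, in_nc=3, nf=64):
        super().__init__()
        self.down1 = nn.Conv2d(in_nc, nf, 3, stride=2, padding=1)
        self.down2 = nn.Conv2d(nf, nf * 2, 3, stride=2, padding=1)
        self.down3 = nn.Conv2d(nf * 2, nf * 4, 3, stride=2, padding=1)
        self.up2 = nn.Conv2d(nf * 4 + nf * 2, nf * 2, 3, padding=1)
        self.up1 = nn.Conv2d(nf * 2 + nf, nf, 3, padding=1)
        self.out = nn.Conv2d(nf, 1, 3, padding=1)

    def forward(self, x):
        d1 = F.leaky_relu(self.down1(x), 0.2)   # 1/2 resolution
        d2 = F.leaky_relu(self.down2(d1), 0.2)  # 1/4 resolution
        d3 = F.leaky_relu(self.down3(d2), 0.2)  # 1/8 resolution
        u2 = F.leaky_relu(self.up2(torch.cat([F.interpolate(d3, scale_factor=2), d2], dim=1)), 0.2)
        u1 = F.leaky_relu(self.up1(torch.cat([F.interpolate(u2, scale_factor=2), d1], dim=1)), 0.2)
        return self.out(u1)  # per-pixel real/fake logits at 1/2 input resolution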
641 lines
32 KiB
Python
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from models.base_model import BaseModel
from models.loss import GANLoss
from apex import amp
import torch.nn.functional as F
import glob
import random

import torchvision.utils as utils
import os

logger = logging.getLogger('base')


class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)

            if 'network_C' in opt.keys():
                self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
                # The corruptor net is fixed. Lock 'her down.
                self.netC.eval()
                for p in self.netC.parameters():
                    p.requires_grad = False
            else:
                self.netC = None

        # define losses, optimizer and scheduler
        if self.is_train:
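            # mega_batch_factor splits each incoming batch into this many micro-batches: feed_data()
            # chunks the tensors and optimize_parameters() accumulates gradients across the chunks
            # before each optimizer step, so large effective batch sizes fit in GPU memory.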
            self.mega_batch_factor = train_opt['mega_batch_factor']
            if self.mega_batch_factor is None:
                self.mega_batch_factor = 1
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
                self.l_fea_w_decay = train_opt['feature_weight_decay']
                self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
                self.l_fea_w_minimum = train_opt['feature_weight_minimum']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0
            self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
            self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
            self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
            self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
            self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # AMP
            [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
                amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)

            # DataParallel
            if opt['dist']:
                self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
            else:
                self.netG = DataParallel(self.netG)
            if self.is_train:
                if opt['dist']:
                    self.netD = DistributedDataParallel(self.netD,
                                                        device_ids=[torch.cuda.current_device()])
                else:
                    self.netD = DataParallel(self.netD)
                self.netG.train()
                self.netD.train()

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state'],
                                                         force_lr=train_opt['force_lr']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('Learning rate scheme [{:s}] is not supported.'.format(train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

            # Swapout params
            self.swapout_G_freq = train_opt['swapout_G_freq'] if train_opt['swapout_G_freq'] else 0
            self.swapout_G_duration = 0
            self.swapout_D_freq = train_opt['swapout_D_freq'] if train_opt['swapout_D_freq'] else 0
            self.swapout_D_duration = 0
            self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0

        self.print_network()  # print network
        self.load()  # load G and D if needed
        self.load_random_corruptor()

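    # Prepares one training batch: optionally pushes the LQ input through the corruptor network
    # (with probability corruptor_usage_prob) to synthesize extra degradation, then splits the
    # LQ/GT/ref/PIX tensors into mega_batch_factor micro-batches for gradient accumulation.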
    def feed_data(self, data, need_GT=True):
        _profile = True
        if _profile:
            from time import time
            _t = time()

        # Corrupt the data with the given corruptor, if specified.
        self.fed_LQ = data['LQ'].to(self.device)
        if self.netC and random.random() < self.corruptor_usage_prob:
            with torch.no_grad():
                corrupted_L = self.netC(self.fed_LQ)[0].detach()
        else:
            corrupted_L = self.fed_LQ

        self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
        if need_GT:
            self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
            self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]

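    # One full training iteration: the generator is updated every D_update_ratio steps once
    # D_init_iters has passed, accumulating losses across the micro-batches produced by feed_data();
    # the discriminator is updated (when l_gan_w > 0) after G_warmup steps, optionally with gaussian
    # input noise and, for 'pixgan', randomized patch swaps between real and fake inputs. Sample
    # images are dumped every 50 steps and the corruptor is re-randomized every corruptor_swapout_steps.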
    def optimize_parameters(self, step):
        _profile = False
        if _profile:
            from time import time
            _t = time()

        # Some generators have variants depending on the current step.
        if hasattr(self.netG.module, "update_for_step"):
            self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        if step > self.D_init_iters:
            self.optimizer_G.zero_grad()

        self.swapout_D(step)
        self.swapout_G(step)

        # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            for p in self.netG.parameters():
                p.requires_grad = True
        else:
            for p in self.netG.parameters():
                p.requires_grad = False

        # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta.
        if self.D_noise_final == 0:
            noise_theta = 0
        else:
            noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor
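        # noise_theta therefore decays linearly from D_noise_theta_init down to D_noise_theta_floor
        # over the first D_noise_final steps, and stays at the floor afterwards.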

        if _profile:
            print("Misc setup %f" % (time() - _t,))
            _t = time()

        self.fake_GenOut = []
        self.fake_H = []
        var_ref_skips = []
        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
            fake_GenOut = self.netG(var_L)

            if _profile:
                print("Gen forward %f" % (time() - _t,))
                _t = time()

            # Extract the image output. For generators that output skip-through connections, the master output is always
            # the first element of the tuple.
            if isinstance(fake_GenOut, tuple):
                gen_img = fake_GenOut[0]
                # The following line detaches all generator outputs that are not None.
                self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
                var_ref = (var_ref,)  # This is a tuple for legacy reasons.
            else:
                gen_img = fake_GenOut
                self.fake_GenOut.append(fake_GenOut.detach())

            l_g_total = 0
            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                if self.cri_pix:  # pixel loss
                    l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
                    l_g_pix_log = l_g_pix / self.l_pix_w
                    l_g_total += l_g_pix
                if self.cri_fea:  # feature loss
                    real_fea = self.netF(pix).detach()
                    fake_fea = self.netF(gen_img)
                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                    l_g_fea_log = l_g_fea / self.l_fea_w
                    l_g_total += l_g_fea

                    if _profile:
                        print("Fea forward %f" % (time() - _t,))
                        _t = time()

                    # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                    # in the resultant image.
                    if step % self.l_fea_w_decay_steps == 0:
                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)

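                # Generator adversarial loss: 'gan' and 'pixgan' score the fake output directly, while
                # 'ragan' uses the relativistic-average formulation (fake vs. mean real and real vs. mean fake).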
                if self.l_gan_w > 0:
                    if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
                        pred_g_fake = self.netD(fake_GenOut)
                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                    elif self.opt['train']['gan_type'] == 'ragan':
                        pred_d_real = self.netD(var_ref).detach()
                        pred_g_fake = self.netD(fake_GenOut)
                        l_g_gan = self.l_gan_w * (
                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
                    l_g_gan_log = l_g_gan / self.l_gan_w
                    l_g_total += l_g_gan

                # Scale the loss down by the batch factor.
                l_g_total_log = l_g_total
                l_g_total = l_g_total / self.mega_batch_factor

                with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                    l_g_total_scaled.backward()

                if _profile:
                    print("Gen backward %f" % (time() - _t,))
                    _t = time()

        self.optimizer_G.step()

        if _profile:
            print("Gen step %f" % (time() - _t,))
            _t = time()

        # D
        if self.l_gan_w > 0 and step > self.G_warmup:
            for p in self.netD.parameters():
                p.requires_grad = True

            noise = torch.randn_like(var_ref[0]) * noise_theta
            noise = noise.to(self.device)  # Tensor.to() is not in-place; keep the returned tensor.
            self.optimizer_D.zero_grad()
            real_disc_images = []
            fake_disc_images = []
            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
                # Re-compute generator outputs (post-update).
                with torch.no_grad():
                    fake_H = self.netG(var_L)
                    # The following line detaches all generator outputs that are not None.
                    fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])

                if _profile:
                    print("Gen forward for disc %f" % (time() - _t,))
                    _t = time()

                # Apply noise to the inputs to slow discriminator convergence.
                var_ref = (var_ref + noise,)
                fake_H = (fake_H[0] + noise,) + fake_H[1:]
                if self.opt['train']['gan_type'] == 'gan':
                    # need to forward and backward separately, since batch norm statistics differ
                    # real
                    pred_d_real = self.netD(var_ref)
                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()
                if self.opt['train']['gan_type'] == 'pixgan':
                    # randomly determine portions of the image to swap to keep the discriminator honest.
                    pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
                    disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
                    b, _, w, h = var_ref[0].shape
                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
                    SWAP_MAX_DIM = w // 4
                    SWAP_MIN_DIM = 16
                    assert SWAP_MAX_DIM > 0
                    random_swap_count = random.randint(0, 4)
                    for i in range(random_swap_count):
                        # Make the swap across fake_H and var_ref
                        swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
                        swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
                        if swap_x + swap_w > w:
                            swap_w = w - swap_x
                        if swap_y + swap_h > h:
                            swap_h = h - swap_y
                        t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                        fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
                        var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                        real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                        fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

                    # Interpolate down to the dimensionality that the discriminator uses.
                    real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
                    fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")

                    # We're also assuming that this is exactly how the flattened discriminator output is generated.
                    real = real.view(-1, 1)
                    fake = fake.view(-1, 1)

                    # real
                    pred_d_real = self.netD(var_ref)
                    l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
                    pdr = pdr / torch.max(pdr)
                    real_disc_images.append(pdr.view(disc_output_shape))
                    pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
                    pdf = pdf / torch.max(pdf)
                    fake_disc_images.append(pdf.view(disc_output_shape))

                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_fake = self.netD(fake_H).detach()
                    pred_d_real = self.netD(var_ref)

                    if _profile:
                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()

                    if _profile:
                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    if _profile:
                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
                var_ref_skips.append(var_ref[0].detach())
                self.fake_H.append(fake_H[0].detach())
            self.optimizer_D.step()

        if _profile:
            print("Disc step %f" % (time() - _t,))
            _t = time()

        # Log sample images from first microbatch.
        if step % 50 == 0:
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
            multi_gen = False
            if isinstance(self.fake_GenOut[0], tuple):
                os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
                multi_gen = True

            # fed_LQ is not chunked.
            for i in range(self.mega_batch_factor):
                utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
                if multi_gen:
                    utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
                    if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
                        utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
                        utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
                        utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
                        utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
                else:
                    utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))

        # Log metrics
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.add_log_entry('l_g_pix', l_g_pix_log.item())
            if self.cri_fea:
                self.add_log_entry('feature_weight', self.l_fea_w)
                self.add_log_entry('l_g_fea', l_g_fea_log.item())
            if self.l_gan_w > 0:
                self.add_log_entry('l_g_gan', l_g_gan_log.item())
            self.add_log_entry('l_g_total', l_g_total_log.item())
        if self.l_gan_w > 0 and step > self.G_warmup:
            self.add_log_entry('l_d_real', l_d_real_log.item())
            self.add_log_entry('l_d_fake', l_d_fake_log.item())
            self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
            self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))

        if step % self.corruptor_swapout_steps == 0 and step > 0:
            self.load_random_corruptor()

    # Allows the log to serve as an easy-to-use rotating buffer.
    def add_log_entry(self, key, value):
        key_it = "%s_it" % (key,)
        log_rotating_buffer_size = 50
        if key not in self.log_dict.keys():
            self.log_dict[key] = []
            self.log_dict[key_it] = 0
        if len(self.log_dict[key]) < log_rotating_buffer_size:
            self.log_dict[key].append(value)
        else:
            self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
        self.log_dict[key_it] += 1

    def pick_rand_prev_model(self, model_suffix):
        previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
        if len(previous_models) <= 1:
            return None
        # Just a note: this intentionally includes the swap model in the list of possibilities.
        return previous_models[random.randint(0, len(previous_models)-1)]

    def compute_fea_loss(self, real, fake):
        with torch.no_grad():
            real = real.unsqueeze(dim=0)
            fake = fake.unsqueeze(dim=0)
            real_fea = self.netF(real).detach()
            fake_fea = self.netF(fake)
            return self.cri_fea(fake_fea, real_fea).item()

    # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
    def force_restore_swapout(self):
        if self.swapout_D_duration > 0:
            logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
            self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
            self.stashed_D = None
            self.swapout_D_duration = 0
        if self.swapout_G_duration > 0:
            logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
            self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
            self.stashed_G = None
            self.swapout_G_duration = 0

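    # Model "swapout": every swapout_D_freq / swapout_G_freq steps, stash the live network to disk,
    # load a randomly chosen previous checkpoint in its place for swapout_duration steps, then restore
    # the stashed weights (presumably to keep G and D from over-fitting to each other's latest state).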
    def swapout_D(self, step):
        if self.swapout_D_duration > 0:
            self.swapout_D_duration -= 1
            if self.swapout_D_duration == 0:
                # Swap back.
                logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
                self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
                self.stashed_D = None
        elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
            swapped_model = self.pick_rand_prev_model('D')
            if swapped_model is not None:
                logger.info("Swapping to previous D model: %s" % (swapped_model,))
                self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
                self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
                self.swapout_D_duration = self.swapout_duration

    def swapout_G(self, step):
        if self.swapout_G_duration > 0:
            self.swapout_G_duration -= 1
            if self.swapout_G_duration == 0:
                # Swap back.
                logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
                self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
                self.stashed_G = None
        elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
            swapped_model = self.pick_rand_prev_model('G')
            if swapped_model is not None:
                logger.info("Swapping to previous G model: %s" % (swapped_model,))
                self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
                self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
                self.swapout_G_duration = self.swapout_duration

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_GenOut = [self.netG(self.var_L[0])]
        self.netG.train()

    # Fetches a summary of the log.
    def get_current_log(self, step):
        return_log = {}
        for k in self.log_dict.keys():
            if not isinstance(self.log_dict[k], list):
                continue
            return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])

        # Some generators can do their own metric logging.
        if hasattr(self.netG.module, "get_debug_values"):
            return_log.update(self.netG.module.get_debug_values(step))

        return return_log

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L[0].detach().float().cpu()
        gen_batch = self.fake_GenOut[0]
        if isinstance(gen_batch, tuple):
            gen_batch = gen_batch[0]
        out_dict['rlt'] = gen_batch.detach().float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H[0].detach().float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

    def load_random_corruptor(self):
        if self.netC is None:
            return
        corruptor_files = glob.glob(os.path.join(self.opt['path']['pretrained_corruptors_dir'], "*.pth"))
        corruptor_to_load = corruptor_files[random.randint(0, len(corruptor_files)-1)]
        logger.info('Swapping corruptor to: %s' % (corruptor_to_load,))
        self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)