Fix megabatch scaling, log low and med-res gen images

James Betker 2020-05-05 08:34:57 -06:00
parent 3b4e54c4c5
commit 9f4581aacb
4 changed files with 43 additions and 27 deletions

View File

@@ -200,6 +200,9 @@ class SRGANModel(BaseModel):
                         self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
                 l_g_total += l_g_gan

+                # Scale the loss down by the batch factor.
+                l_g_total = l_g_total / self.mega_batch_factor
+
                 with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                     l_g_total_scaled.backward()
         self.optimizer_G.step()
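
Note: dividing by self.mega_batch_factor is the usual gradient-accumulation correction. The mega-batch is split into mega_batch_factor chunks, each chunk's backward() adds into the parameters' .grad buffers, and one optimizer step is taken at the end; scaling every chunk loss by 1/factor makes the summed gradient equal the mean over the whole mega-batch. A minimal standalone sketch of the pattern (model, criterion, and chunk names are illustrative, not this repo's API):

    # Hypothetical sketch of loss scaling under gradient accumulation.
    import torch

    def accumulated_step(model, optimizer, criterion, lr_chunks, hr_chunks):
        factor = len(lr_chunks)
        optimizer.zero_grad()
        for lr, hr in zip(lr_chunks, hr_chunks):
            loss = criterion(model(lr), hr) / factor  # scale so summed grads match the full-batch mean
            loss.backward()                           # gradients accumulate in .grad across chunks
        optimizer.step()                              # one update per mega-batch
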
@@ -223,12 +226,12 @@ class SRGANModel(BaseModel):
                 # need to forward and backward separately, since batch norm statistics differ
                 # real
                 pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real, True)
+                l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
-                pred_d_fake = self.netD(fake_H) # detach to avoid BP to G
-                l_d_fake = self.cri_gan(pred_d_fake, False)
+                pred_d_fake = self.netD(fake_H)
+                l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
             elif self.opt['train']['gan_type'] == 'ragan':
@@ -240,11 +243,11 @@ class SRGANModel(BaseModel):
                 # l_d_total.backward()
                 pred_d_fake = self.netD(fake_H).detach()
                 pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
         self.optimizer_D.step()
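
Note: on the discriminator side the same 1/mega_batch_factor correction is applied to every per-chunk loss, on top of the existing 0.5 weighting that splits the relativistic loss between its real and fake halves. The distinct loss_id values let Apex AMP keep an independent dynamic loss scale per loss. A hedged sketch of the 'ragan' loss shape, with binary_cross_entropy_with_logits standing in for this repo's cri_gan (an assumption; the actual criterion is configurable):

    # Hypothetical RaGAN discriminator loss with per-chunk scaling.
    import torch
    import torch.nn.functional as F

    def ragan_d_loss(pred_real, pred_fake, mega_batch_factor):
        # real logits are pushed above the mean fake logit, and vice versa
        l_real = F.binary_cross_entropy_with_logits(
            pred_real - pred_fake.mean(), torch.ones_like(pred_real))
        l_fake = F.binary_cross_entropy_with_logits(
            pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
        # each half weighted 0.5, then divided by the accumulation factor as in the diff
        return 0.5 * (l_real + l_fake) / mega_batch_factor
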
@@ -255,14 +258,24 @@ class SRGANModel(BaseModel):
             os.makedirs("temp/lr", exist_ok=True)
             os.makedirs("temp/gen", exist_ok=True)
             os.makedirs("temp/pix", exist_ok=True)
-            gen_batch = self.fake_GenOut[0]
-            if isinstance(gen_batch, tuple):
-                gen_batch = gen_batch[0]
-            for i in range(self.var_L[0].shape[0]):
-                utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
-                utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+            multi_gen = False
+            if isinstance(self.fake_GenOut[0], tuple):
+                os.makedirs("temp/genlr", exist_ok=True)
+                os.makedirs("temp/genmr", exist_ok=True)
+                os.makedirs("temp/ref", exist_ok=True)
+                multi_gen = True
+            for i in range(self.mega_batch_factor):
+                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
+                if multi_gen:
+                    utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
+                else:
+                    utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))

         # set log TODO(handle mega-batches?)
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
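
Note: the rewritten dump loop writes one image file per mega-batch chunk rather than per sample, and when the generator returns a (full, medium, low) resolution tuple it also saves the medium- and low-res outputs and reference skips, matching the commit title. torchvision's save_image tiles a 4D tensor into a single grid automatically, which is what makes passing a whole chunk work. A standalone sketch of that behavior with random placeholder tensors:

    # save_image tiles a batched (4D) tensor into one PNG grid.
    import os
    import torch
    from torchvision import utils

    step = 0
    outputs = {"gen": torch.rand(8, 3, 256, 256),    # full-res chunk
               "genmr": torch.rand(8, 3, 128, 128),  # medium-res chunk
               "genlr": torch.rand(8, 3, 64, 64)}    # low-res chunk
    for name, batch in outputs.items():
        os.makedirs(os.path.join("temp", name), exist_ok=True)
        utils.save_image(batch, os.path.join("temp", name, "%05i_00.png" % step))
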
@@ -272,9 +285,9 @@ class SRGANModel(BaseModel):
             self.log_dict['feature_weight'] = self.l_fea_w
             self.log_dict['l_g_fea'] = l_g_fea.item()
             self.log_dict['l_g_gan'] = l_g_gan.item()
-            self.log_dict['l_g_total'] = l_g_total.item()
-        self.log_dict['l_d_real'] = l_d_real.item()
-        self.log_dict['l_d_fake'] = l_d_fake.item()
+            self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor
+        self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor
+        self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor
         self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

     def create_artificial_skips(self, truth_img):
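
Note: the logs multiply back by self.mega_batch_factor because the losses were divided by it before backward(); undoing the division keeps reported magnitudes comparable across runs with different accumulation factors. Trivially:

    # The factor cancels in the logs (values illustrative).
    mega_batch_factor = 3
    raw_loss = 0.12
    chunk_loss = raw_loss / mega_batch_factor                       # what backward() saw
    assert abs(chunk_loss * mega_batch_factor - raw_loss) < 1e-12   # what gets logged
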

View File

@@ -4,14 +4,14 @@ model: sr
 distortion: sr
 scale: 4
 crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
-#gpu_ids: [0]
+gpu_ids: [0]

 datasets:
   test_1: # the 1st test dataset
     name: set5
     mode: LQ
-    batch_size: 1
-    dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
+    batch_size: 16
+    dataroot_LQ: ..\..\datasets\adrianna\full_extract

 #### network structures
 network_G:

View File

@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
     doCrop: false
     use_shuffle: true
-    n_workers: 8 # per GPU
-    batch_size: 6
+    n_workers: 12 # per GPU
+    batch_size: 24
     target_size: 256
     color: RGB
   val:
@@ -42,18 +42,18 @@ network_D:

 #### path
 path:
-  pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
-  pretrain_model_D: ~
+  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
+  #pretrain_model_D: ~
   strict_load: true
   resume_state: ~

 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 1e-4
+  lr_G: !!float 2e-4
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 1e-4
+  lr_D: !!float 4e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -63,7 +63,7 @@ train:
   warmup_iter: -1 # no warm up
   lr_steps: [20000, 40000, 50000, 60000]
   lr_gamma: 0.5
-  mega_batch_factor: 1
+  mega_batch_factor: 3

   pixel_criterion: l1
   pixel_weight: !!float 1e-2
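
Note: under this config each optimizer step still consumes batch_size: 24 samples, but with mega_batch_factor: 3 the forward/backward passes run on chunks of 8, so peak activation memory is roughly that of a batch of 8. A sketch of the split, assuming torch.chunk-style slicing (the repo's actual mechanism may differ):

    # Assumed mechanics of splitting a mega-batch into chunks.
    import torch

    batch = torch.rand(24, 3, 64, 64)        # batch_size: 24
    chunks = torch.chunk(batch, 3, dim=0)    # mega_batch_factor: 3
    assert len(chunks) == 3 and all(c.shape[0] == 8 for c in chunks)
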

View File

@@ -2,3 +2,6 @@ rm gen/*
 rm hr/*
 rm lr/*
 rm pix/*
+rm ref/*
+rm genlr/*
+rm genmr/*