forked from mrq/DL-Art-School
Fix megabatch scaling, log low and med-res gen images
This commit is contained in:
parent
3b4e54c4c5
commit
9f4581aacb
|
@ -200,6 +200,9 @@ class SRGANModel(BaseModel):
|
|||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
|
||||
# Scale the loss down by the batch factor.
|
||||
l_g_total = l_g_total / self.mega_batch_factor
|
||||
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
self.optimizer_G.step()
|
||||
|
@ -223,12 +226,12 @@ class SRGANModel(BaseModel):
|
|||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
# fake
|
||||
pred_d_fake = self.netD(fake_H) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
pred_d_fake = self.netD(fake_H)
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
|
@ -240,11 +243,11 @@ class SRGANModel(BaseModel):
|
|||
# l_d_total.backward()
|
||||
pred_d_fake = self.netD(fake_H).detach()
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
pred_d_fake = self.netD(fake_H)
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
self.optimizer_D.step()
|
||||
|
@ -255,14 +258,24 @@ class SRGANModel(BaseModel):
|
|||
os.makedirs("temp/lr", exist_ok=True)
|
||||
os.makedirs("temp/gen", exist_ok=True)
|
||||
os.makedirs("temp/pix", exist_ok=True)
|
||||
gen_batch = self.fake_GenOut[0]
|
||||
if isinstance(gen_batch, tuple):
|
||||
gen_batch = gen_batch[0]
|
||||
for i in range(self.var_L[0].shape[0]):
|
||||
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
multi_gen = False
|
||||
if isinstance(self.fake_GenOut[0], tuple):
|
||||
os.makedirs("temp/genlr", exist_ok=True)
|
||||
os.makedirs("temp/genmr", exist_ok=True)
|
||||
os.makedirs("temp/ref", exist_ok=True)
|
||||
multi_gen = True
|
||||
for i in range(self.mega_batch_factor):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
|
||||
if multi_gen:
|
||||
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
|
||||
else:
|
||||
utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# set log TODO(handle mega-batches?)
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
|
@ -272,9 +285,9 @@ class SRGANModel(BaseModel):
|
|||
self.log_dict['feature_weight'] = self.l_fea_w
|
||||
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
self.log_dict['l_g_total'] = l_g_total.item()
|
||||
self.log_dict['l_d_real'] = l_d_real.item()
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item()
|
||||
self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor
|
||||
self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor
|
||||
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||
|
||||
def create_artificial_skips(self, truth_img):
|
||||
|
|
|
@ -4,14 +4,14 @@ model: sr
|
|||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
#gpu_ids: [0]
|
||||
gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set5
|
||||
mode: LQ
|
||||
batch_size: 1
|
||||
dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
|
||||
batch_size: 16
|
||||
dataroot_LQ: ..\..\datasets\adrianna\full_extract
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
|
|
|
@ -16,8 +16,8 @@ datasets:
|
|||
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
|
||||
doCrop: false
|
||||
use_shuffle: true
|
||||
n_workers: 8 # per GPU
|
||||
batch_size: 6
|
||||
n_workers: 12 # per GPU
|
||||
batch_size: 24
|
||||
target_size: 256
|
||||
color: RGB
|
||||
val:
|
||||
|
@ -42,18 +42,18 @@ network_D:
|
|||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
|
||||
pretrain_model_D: ~
|
||||
#pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
|
||||
#pretrain_model_D: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
lr_G: !!float 2e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
lr_D: !!float 4e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
|
@ -63,7 +63,7 @@ train:
|
|||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [20000, 40000, 50000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 1
|
||||
mega_batch_factor: 3
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
rm gen/*
|
||||
rm hr/*
|
||||
rm lr/*
|
||||
rm pix/*
|
||||
rm pix/*
|
||||
rm ref/*
|
||||
rm genlr/*
|
||||
rm genmr/*
|
Loading…
Reference in New Issue
Block a user