diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 99068573..a691015e 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -200,6 +200,9 @@ class SRGANModel(BaseModel):
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
                 l_g_total += l_g_gan
 
+            # Scale the loss down by the batch factor.
+            l_g_total = l_g_total / self.mega_batch_factor
+
             with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                 l_g_total_scaled.backward()
             self.optimizer_G.step()
@@ -223,12 +226,12 @@ class SRGANModel(BaseModel):
             # need to forward and backward separately, since batch norm statistics differ
             # real
             pred_d_real = self.netD(var_ref)
-            l_d_real = self.cri_gan(pred_d_real, True)
+            l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             # fake
-            pred_d_fake = self.netD(fake_H)  # detach to avoid BP to G
-            l_d_fake = self.cri_gan(pred_d_fake, False)
+            pred_d_fake = self.netD(fake_H)
+            l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
         elif self.opt['train']['gan_type'] == 'ragan':
@@ -240,11 +243,11 @@ class SRGANModel(BaseModel):
             # l_d_total.backward()
             pred_d_fake = self.netD(fake_H).detach()
             pred_d_real = self.netD(var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             pred_d_fake = self.netD(fake_H)
-            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
         self.optimizer_D.step()
@@ -255,14 +258,24 @@ class SRGANModel(BaseModel):
         os.makedirs("temp/lr", exist_ok=True)
         os.makedirs("temp/gen", exist_ok=True)
         os.makedirs("temp/pix", exist_ok=True)
-        gen_batch = self.fake_GenOut[0]
-        if isinstance(gen_batch, tuple):
-            gen_batch = gen_batch[0]
-        for i in range(self.var_L[0].shape[0]):
-            utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
-            utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
-            utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
-            utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+        multi_gen = False
+        if isinstance(self.fake_GenOut[0], tuple):
+            os.makedirs("temp/genlr", exist_ok=True)
+            os.makedirs("temp/genmr", exist_ok=True)
+            os.makedirs("temp/ref", exist_ok=True)
+            multi_gen = True
+        for i in range(self.mega_batch_factor):
+            utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
+            utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
+            utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
+            if multi_gen:
+                utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
+                utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
+            else:
+                utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
 
         # set log TODO(handle mega-batches?)
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -272,9 +285,9 @@ class SRGANModel(BaseModel):
             self.log_dict['feature_weight'] = self.l_fea_w
             self.log_dict['l_g_fea'] = l_g_fea.item()
             self.log_dict['l_g_gan'] = l_g_gan.item()
-            self.log_dict['l_g_total'] = l_g_total.item()
-            self.log_dict['l_d_real'] = l_d_real.item()
-            self.log_dict['l_d_fake'] = l_d_fake.item()
+            self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor
+            self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor
+            self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor
         self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
 
     def create_artificial_skips(self, truth_img):
diff --git a/codes/options/test/test_ESRGAN_adrianna_full.yml b/codes/options/test/test_ESRGAN_adrianna_full.yml
index b54cd8a4..8d64c2b9 100644
--- a/codes/options/test/test_ESRGAN_adrianna_full.yml
+++ b/codes/options/test/test_ESRGAN_adrianna_full.yml
@@ -4,14 +4,14 @@ model: sr
 distortion: sr
 scale: 4
 crop_border: ~  # crop border when evaluation. If None(~), crop the scale pixels
-#gpu_ids: [0]
+gpu_ids: [0]
 
 datasets:
   test_1:  # the 1st test dataset
     name: set5
     mode: LQ
-    batch_size: 1
-    dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
+    batch_size: 16
+    dataroot_LQ: ..\..\datasets\adrianna\full_extract
 
 #### network structures
 network_G:
diff --git a/codes/options/train/train_ESRGAN_blacked_xl.yml b/codes/options/train/train_ESRGAN_blacked_xl.yml
index 7f985e25..b4526337 100644
--- a/codes/options/train/train_ESRGAN_blacked_xl.yml
+++ b/codes/options/train/train_ESRGAN_blacked_xl.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
     doCrop: false
     use_shuffle: true
-    n_workers: 8  # per GPU
-    batch_size: 6
+    n_workers: 12  # per GPU
+    batch_size: 24
     target_size: 256
     color: RGB
   val:
@@ -42,18 +42,18 @@ network_D:
 
 #### path
 path:
-  pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
-  pretrain_model_D: ~
+  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
+  #pretrain_model_D: ~
   strict_load: true
   resume_state: ~
 
#### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 1e-4
+  lr_G: !!float 2e-4
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 1e-4
+  lr_D: !!float 4e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -63,7 +63,7 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [20000, 40000, 50000, 60000]
   lr_gamma: 0.5
-  mega_batch_factor: 1
+  mega_batch_factor: 3
 
   pixel_criterion: l1
   pixel_weight: !!float 1e-2
diff --git a/codes/temp/cleanup.sh b/codes/temp/cleanup.sh
index 95f03933..24562500 100644
--- a/codes/temp/cleanup.sh
+++ b/codes/temp/cleanup.sh
@@ -1,4 +1,7 @@
 rm gen/*
 rm hr/*
 rm lr/*
-rm pix/*
\ No newline at end of file
+rm pix/*
+rm ref/*
+rm genlr/*
+rm genmr/*
\ No newline at end of file
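Note: the mega_batch_factor changes above are the loss side of gradient accumulation. Each sub-batch loss is divided by the factor before backward() so the summed gradients match a single full-batch step (for equal-sized sub-batches), and the logged values are multiplied back so reported numbers stay comparable across factor settings. A minimal stand-alone sketch of that pattern in plain PyTorch (apex amp omitted; accumulate_step, model, criterion, and chunks are hypothetical names for illustration, not identifiers from this repo):

import torch
import torch.nn as nn

def accumulate_step(model, criterion, optimizer, chunks):
    """One optimizer step spread over len(chunks) forward/backward passes."""
    factor = len(chunks)                    # plays the role of mega_batch_factor
    optimizer.zero_grad()
    last_loss = 0.0
    for inputs, targets in chunks:
        # Scale down before backward so accumulated grads average, not sum.
        loss = criterion(model(inputs), targets) / factor
        loss.backward()                     # gradients accumulate in .grad
        last_loss = loss.item() * factor    # scale back up for logging
    optimizer.step()
    return last_loss

# Hypothetical usage: three sub-batches emulate mega_batch_factor: 3 from the config.
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
chunks = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(3)]
print(accumulate_step(model, nn.L1Loss(), optimizer, chunks))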