diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index ed9129c1..e61295fd 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -31,6 +31,9 @@ class SRGANModel(BaseModel):
 
         # define losses, optimizer and scheduler
         if self.is_train:
+            self.mega_batch_factor = train_opt['mega_batch_factor']
+            if self.mega_batch_factor is None:
+                self.mega_batch_factor = 1
             # G pixel loss
             if train_opt['pixel_weight'] > 0:
                 l_pix_type = train_opt['pixel_criterion']
@@ -138,12 +141,12 @@ class SRGANModel(BaseModel):
         self.load()  # load G and D if needed
 
     def feed_data(self, data, need_GT=True):
-        self.var_L = data['LQ'].to(self.device)  # LQ
+        self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]  # LQ
         if need_GT:
-            self.var_H = data['GT'].to(self.device)  # GT
+            self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
             input_ref = data['ref'] if 'ref' in data else data['GT']
-            self.var_ref = input_ref.to(self.device)
-            self.pix = data['PIX'].to(self.device)
+            self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
+            self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
         # G
@@ -152,84 +155,95 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
-            self.fake_H = self.netG(self.var_L)
-        else:
-            self.fake_H = self.pix
-        if step % 50 == 0:
-            for i in range(self.var_L.shape[0]):
-                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
+        self.fake_H = []
+        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
+            if step > self.D_init_iters:
+                fake_H = self.netG(var_L)
+            else:
+                fake_H = pix
+            self.fake_H.append(fake_H.detach())
 
-        l_g_total = 0
-        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
-            if self.cri_pix:  # pixel loss
-                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
-                l_g_total += l_g_pix
-            if self.cri_fea:  # feature loss
-                real_fea = self.netF(self.pix).detach()
-                fake_fea = self.netF(self.fake_H)
-                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
-                l_g_total += l_g_fea
+            l_g_total = 0
+            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+                if self.cri_pix:  # pixel loss
+                    l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
+                    l_g_total += l_g_pix
+                if self.cri_fea:  # feature loss
+                    real_fea = self.netF(pix).detach()
+                    fake_fea = self.netF(fake_H)
+                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                    l_g_total += l_g_fea
 
-            # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
-            # in the resultant image.
-            if step % self.l_fea_w_decay_steps == 0:
-                self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
+                # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
+                # in the resultant image.
+                if step % self.l_fea_w_decay_steps == 0:
+                    self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-            if self.opt['train']['gan_type'] == 'gan':
-                pred_g_fake = self.netD(self.fake_H)
-                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-            elif self.opt['train']['gan_type'] == 'ragan':
-                pred_d_real = self.netD(self.var_ref).detach()
-                pred_g_fake = self.netD(self.fake_H)
-                l_g_gan = self.l_gan_w * (
-                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-            l_g_total += l_g_gan
+                if self.opt['train']['gan_type'] == 'gan':
+                    pred_g_fake = self.netD(fake_H)
+                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                elif self.opt['train']['gan_type'] == 'ragan':
+                    pred_d_real = self.netD(var_ref).detach()
+                    pred_g_fake = self.netD(fake_H)
+                    l_g_gan = self.l_gan_w * (
+                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                l_g_total += l_g_gan
 
-            with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
-                l_g_total_scaled.backward()
-            self.optimizer_G.step()
+                with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
+                    l_g_total_scaled.backward()
+        self.optimizer_G.step()
 
         # D
         for p in self.netD.parameters():
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
-        if self.opt['train']['gan_type'] == 'gan':
-            # need to forward and backward separately, since batch norm statistics differ
-            # real
-            pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real, True)
-            with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                l_d_real_scaled.backward()
-            # fake
-            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
-            l_d_fake = self.cri_gan(pred_d_fake, False)
-            with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                l_d_fake_scaled.backward()
-        elif self.opt['train']['gan_type'] == 'ragan':
-            # pred_d_real = self.netD(self.var_ref)
-            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
-            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-            # l_d_total = (l_d_real + l_d_fake) / 2
-            # l_d_total.backward()
-            pred_d_fake = self.netD(self.fake_H.detach()).detach()
-            pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
-            with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                l_d_real_scaled.backward()
-            pred_d_fake = self.netD(self.fake_H.detach())
-            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
-            with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                l_d_fake_scaled.backward()
+        for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
+            if self.opt['train']['gan_type'] == 'gan':
+                # need to forward and backward separately, since batch norm statistics differ
+                # real
+                pred_d_real = self.netD(var_ref)
+                l_d_real = self.cri_gan(pred_d_real, True)
+                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                    l_d_real_scaled.backward()
+                # fake
+                pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                l_d_fake = self.cri_gan(pred_d_fake, False)
+                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                    l_d_fake_scaled.backward()
+            elif self.opt['train']['gan_type'] == 'ragan':
+                # pred_d_real = self.netD(var_ref)
+                # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+                # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+                # l_d_total = (l_d_real + l_d_fake) / 2
+                # l_d_total.backward()
+                pred_d_fake = self.netD(fake_H.detach()).detach()
+                pred_d_real = self.netD(var_ref)
+                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                    l_d_real_scaled.backward()
+                pred_d_fake = self.netD(fake_H.detach())
+                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                    l_d_fake_scaled.backward()
 
         self.optimizer_D.step()
 
-        # set log
+        # Log sample images from first microbatch.
+        if step % 50 == 0:
+            os.makedirs("temp/hr", exist_ok=True)
+            os.makedirs("temp/lr", exist_ok=True)
+            os.makedirs("temp/gen", exist_ok=True)
+            os.makedirs("temp/pix", exist_ok=True)
+            for i in range(self.var_L[0].shape[0]):
+                utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+
+        # set log TODO(handle mega-batches?)
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:
                 self.log_dict['l_g_pix'] = l_g_pix.item()
@@ -245,7 +259,7 @@ class SRGANModel(BaseModel):
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            self.fake_H = self.netG(self.var_L)
+            self.fake_H = [self.netG(self.var_L[0])]
         self.netG.train()
 
     def get_current_log(self):
@@ -253,10 +267,10 @@ class SRGANModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
-        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
+        out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
         if need_GT:
-            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
+            out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
         return out_dict
 
     def print_network(self):
diff --git a/codes/options/train/train_ESRGAN_blacked.yml b/codes/options/train/train_ESRGAN_blacked.yml
index 64e1af42..86e25dd4 100644
--- a/codes/options/train/train_ESRGAN_blacked.yml
+++ b/codes/options/train/train_ESRGAN_blacked.yml
@@ -5,7 +5,7 @@ model: srgan
 distortion: sr
 scale: 4
 gpu_ids: [0]
-amp_opt_level: O1
+amp_opt_level: O0
 
 #### datasets
 datasets:
@@ -14,10 +14,10 @@ datasets:
     mode: LQGT
     dataroot_GT: K:\4k6k\4k_closeup\hr
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
-
+    doCrop: false
     use_shuffle: true
     n_workers: 12  # per GPU
-    batch_size: 12
+    batch_size: 64
     target_size: 256
     color: RGB
   val:
@@ -40,17 +40,18 @@ network_D:
 
 #### path
 path:
-  pretrain_model_G: ~
+  pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
+  pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
   strict_load: true
   resume_state: ~
 
 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 1e-4
+  lr_G: !!float 5e-5
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 2e-4
+  lr_D: !!float 8e-5
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -58,21 +59,22 @@ train:
 
   niter: 400000
   warmup_iter: -1  # no warm up
-  lr_steps: [20000, 40000, 60000, 80000]
+  lr_steps: [5000, 20000, 40000, 60000]
   lr_gamma: 0.5
+  mega_batch_factor: 8
 
   pixel_criterion: l1
  pixel_weight: !!float 1e-2
   feature_criterion: l1
-  feature_weight: 1
-  feature_weight_decay: .98
+  feature_weight: 0
+  feature_weight_decay: .9
   feature_weight_decay_steps: 500
   feature_weight_minimum: .1
   gan_type: gan  # gan | ragan
-  gan_weight: !!float 5e-3
+  gan_weight: 1
 
   D_update_ratio: 1
-  D_init_iters: 0
+  D_init_iters: -1
 
   manual_seed: 10
   val_freq: !!float 5e2
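
Note on the mega_batch_factor pattern above: it is plain gradient accumulation. The loader still emits one large "mega" batch (batch_size: 64), feed_data() splits it into mega_batch_factor micro-batches with torch.chunk(), optimize_parameters() runs forward/backward once per chunk so only one chunk's activations are live at a time, and each optimizer steps once after its loop. A minimal, self-contained sketch of that pattern, using a toy model and synthetic tensors (net, loss_fn, big_batch, big_target are illustrative names, not code from this repo):

    import torch
    import torch.nn as nn

    mega_batch_factor = 8
    net = nn.Linear(16, 1)                  # stands in for netG
    loss_fn = nn.MSELoss()                  # stands in for the combined G losses
    opt = torch.optim.Adam(net.parameters(), lr=5e-5)

    big_batch = torch.randn(64, 16)         # stands in for data['LQ']
    big_target = torch.randn(64, 1)         # stands in for data['GT']

    opt.zero_grad()
    for x, y in zip(torch.chunk(big_batch, chunks=mega_batch_factor, dim=0),
                    torch.chunk(big_target, chunks=mega_batch_factor, dim=0)):
        # Dividing by the chunk count makes the accumulated gradient match a
        # single full-batch pass. The patch itself skips this normalization;
        # its loss weights absorb the scale difference.
        loss = loss_fn(net(x), y) / mega_batch_factor
        loss.backward()   # .grad buffers sum across iterations
    opt.step()            # one parameter update per mega-batch

Peak activation memory then scales with batch_size / mega_batch_factor instead of batch_size, which is what allows batch_size to grow from 12 to 64 in the config.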