diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 117908e8..ccff386f 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -19,12 +19,14 @@ class LQGTDataset(data.Dataset): self.data_type = self.opt['data_type'] self.paths_LQ, self.paths_GT = None, None self.sizes_LQ, self.sizes_GT = None, None - self.LQ_env, self.GT_env = None, None # environments for lmdb + self.paths_PIX, self.sizes_PIX = None, None + self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdb self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) - if 'dataroot_PIX' in opt: + if 'dataroot_PIX' in opt.keys(): self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) + assert self.paths_GT, 'Error: GT path is empty.' if self.paths_LQ and self.paths_GT: assert len(self.paths_LQ) == len( @@ -39,7 +41,7 @@ class LQGTDataset(data.Dataset): meminit=False) self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, meminit=False) - if 'dataroot_PIX' in self.opt: + if 'dataroot_PIX' in self.opt.keys(): self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False, meminit=False) @@ -61,12 +63,13 @@ class LQGTDataset(data.Dataset): img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] # get the pix image - if self.PIX_path is not None: - PIX_path = self.PIX_path[index] + if self.paths_PIX is not None: + PIX_path = self.paths_PIX[index] img_PIX = util.read_img(self.PIX_env, PIX_path, resolution) if self.opt['color']: # change color space if necessary img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0] - + else: + img_PIX = img_GT # get LQ image if self.paths_LQ: @@ -98,14 +101,8 @@ class LQGTDataset(data.Dataset): img_LQ = np.expand_dims(img_LQ, axis=2) if self.opt['phase'] == 'train': - # if the image size is too small H, W, _ = img_GT.shape - if H < GT_size or W < GT_size: - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) + assert H >= GT_size and W >= GT_size H, W, C = img_LQ.shape LQ_size = GT_size // scale @@ -116,9 +113,10 @@ class LQGTDataset(data.Dataset): img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] # augmentation - flip, rotate - img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], + img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'], self.opt['use_rot']) if self.opt['color']: # change color space if necessary @@ -129,12 +127,14 @@ class LQGTDataset(data.Dataset): if img_GT.shape[2] == 3: img_GT = img_GT[:, :, [2, 1, 0]] img_LQ = img_LQ[:, :, [2, 1, 0]] + img_PIX = img_PIX[:, :, [2, 1, 0]] img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float() img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() if LQ_path is None: LQ_path = GT_path - return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path} def __len__(self): return len(self.paths_GT) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 78032bfe..9af996f0 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -140,10 +140,7 @@ class SRGANModel(BaseModel): self.var_H = data['GT'].to(self.device) # GT input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) - if 'PIX' in data: - self.pix = data['PIX'] - else: - self.pix = self.var_H + self.pix = data['PIX'].to(self.device) def optimize_parameters(self, step): @@ -151,6 +148,7 @@ class SRGANModel(BaseModel): for i in range(self.var_L.shape[0]): utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i))) # G for p in self.netD.parameters(): @@ -165,7 +163,7 @@ class SRGANModel(BaseModel): l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix) l_g_total += l_g_pix if self.cri_fea: # feature loss - real_fea = self.netF(self.var_H).detach() + real_fea = self.netF(self.pix).detach() fake_fea = self.netF(self.fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea diff --git a/codes/options/train/finetune_ESRGAN_blacked.yml b/codes/options/train/finetune_ESRGAN_blacked.yml index 3a2666cf..3e379d2f 100644 --- a/codes/options/train/finetune_ESRGAN_blacked.yml +++ b/codes/options/train/finetune_ESRGAN_blacked.yml @@ -1,5 +1,5 @@ #### general settings -name: ESRGANx4_blacked_ft +name: ESRGANx4_blacked_lqprn use_tb_logger: true model: srgan distortion: sr @@ -13,7 +13,8 @@ datasets: name: blacked mode: LQGT dataroot_GT: ../datasets/blacked/train/hr - dataroot_LQ: ../datasets/blacked/train/lr + dataroot_LQ: ../datasets/lqprn/train/lr + dataroot_PIX: ../datasets/lqprn/train/hr use_shuffle: true n_workers: 4 # per GPU @@ -42,10 +43,10 @@ network_D: #### path path: - pretrain_model_G: ../experiments/ESRGANx4_blacked_ft/models/31500_G.pth - pretrain_model_D: ../experiments/ESRGANx4_blacked_ft/models/31500_D.pth + pretrain_model_G: ../experiments/blacked_gen_20000_epochs.pth + pretrain_model_D: ../experiments/blacked_disc_20000_epochs.pth + resume_state: ~ strict_load: true - resume_state: ../experiments/ESRGANx4_blacked_ft/training_state/31500.state #### training settings: learning rate scheme, loss train: @@ -65,7 +66,7 @@ train: lr_gamma: 0.5 pixel_criterion: l1 - pixel_weight: !!float 1e-2 + pixel_weight: !!float 5e-3 feature_criterion: l1 feature_weight: 1 gan_type: ragan # gan | ragan