diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 1d7645b4..117908e8 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -23,6 +23,8 @@ class LQGTDataset(data.Dataset): self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + if 'dataroot_PIX' in opt: + self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) assert self.paths_GT, 'Error: GT path is empty.' if self.paths_LQ and self.paths_GT: assert len(self.paths_LQ) == len( @@ -37,6 +39,9 @@ class LQGTDataset(data.Dataset): meminit=False) self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, meminit=False) + if 'dataroot_PIX' in self.opt: + self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False, + meminit=False) def __getitem__(self, index): if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): @@ -55,6 +60,14 @@ class LQGTDataset(data.Dataset): if self.opt['color']: # change color space if necessary img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + # get the pix image + if self.PIX_path is not None: + PIX_path = self.PIX_path[index] + img_PIX = util.read_img(self.PIX_env, PIX_path, resolution) + if self.opt['color']: # change color space if necessary + img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0] + + # get LQ image if self.paths_LQ: LQ_path = self.paths_LQ[index] diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 69dfb04c..78032bfe 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -140,6 +140,10 @@ class SRGANModel(BaseModel): self.var_H = data['GT'].to(self.device) # GT input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) + if 'PIX' in data: + self.pix = data['PIX'] + else: + self.pix = self.var_H def optimize_parameters(self, step): @@ -148,7 +152,6 @@ class SRGANModel(BaseModel): utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) - # G for p in self.netD.parameters(): p.requires_grad = False @@ -159,7 +162,7 @@ class SRGANModel(BaseModel): l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss - l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) + l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach()