From 9a8f22750151ea9d47ea3bc45400af308c6ae000 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 26 Jul 2020 21:44:45 -0600 Subject: [PATCH] Allow separate dataset to pushed in for GAN-only training --- codes/data/LQGT_dataset.py | 23 +++++++++++++++++++++- codes/data/util.py | 2 ++ codes/models/SRGAN_model.py | 38 +++++++++++++++++++++++++++++-------- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index b17b21a4..183da765 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -27,6 +27,7 @@ class LQGTDataset(data.Dataset): self.paths_LQ, self.paths_GT = None, None self.sizes_LQ, self.sizes_GT = None, None self.paths_PIX, self.sizes_PIX = None, None + self.paths_GAN, self.sizes_GAN = None, None self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 @@ -45,6 +46,10 @@ class LQGTDataset(data.Dataset): self.doCrop = opt['doCrop'] if 'dataroot_PIX' in opt.keys(): self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) + # dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where + # LR and HR do not need to be paired. + if 'dataroot_GAN' in opt.keys(): + self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN']) assert self.paths_GT, 'Error: GT path is empty.' if self.paths_LQ and self.paths_GT: @@ -127,6 +132,11 @@ class LQGTDataset(data.Dataset): if img_LQ.ndim == 2: img_LQ = np.expand_dims(img_LQ, axis=2) + img_GAN = None + if self.paths_GAN: + GAN_path = self.paths_GAN[index % self.sizes_GAN] + img_GAN = util.read_img(self.LQ_env, GAN_path) + # Enforce force_resize constraints. h, w, _ = img_LQ.shape if h % self.force_multiple != 0 or w % self.force_multiple != 0: @@ -149,11 +159,15 @@ class LQGTDataset(data.Dataset): rnd_h = random.randint(0, max(0, H - LQ_size)) rnd_w = random.randint(0, max(0, W - LQ_size)) img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + if img_GAN is not None: + img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] else: img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + if img_GAN is not None: + img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) @@ -186,6 +200,8 @@ class LQGTDataset(data.Dataset): if img_GT.shape[2] == 3: img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB) img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB) + if img_GAN is not None: + img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB) img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB) # LQ needs to go to a PIL image to perform the compression-artifact transformation. @@ -204,13 +220,18 @@ class LQGTDataset(data.Dataset): img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float() img_LQ = F.to_tensor(img_LQ) + if img_GAN is not None: + img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float() lq_noise = torch.randn_like(img_LQ) * 5 / 255 img_LQ += lq_noise if LQ_path is None: LQ_path = GT_path - return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path} + d = {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path} + if img_GAN is not None: + d['GAN'] = img_GAN + return d def __len__(self): return len(self.paths_GT) diff --git a/codes/data/util.py b/codes/data/util.py index 4b7f5fc4..d7f7d302 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -62,8 +62,10 @@ def get_image_paths(data_type, dataroot, weights=[]): for j in range(extends): paths.extend(_get_paths_from_images(r)) paths = sorted(paths) + sizes = len(paths) else: paths = sorted(_get_paths_from_images(dataroot)) + sizes = len(paths) else: raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) return paths, sizes diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 98aa3caf..c651dbad 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -197,6 +197,9 @@ class SRGANModel(BaseModel): self.swapout_D_duration = 0 self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0 + # GAN LQ image params + self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0 + self.print_network() # print network self.load() # load G and D if needed self.load_random_corruptor() @@ -225,6 +228,13 @@ class SRGANModel(BaseModel): self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)] + if 'GAN' in data.keys(): + self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)] + else: + # If not provided, use provided LQ for anyplace where the GAN would have been used. + self.gan_img = self.var_L + self.gan_lq_img_use_prob = 0 # Safety valve for not goofing. + if not self.updated: self.netG.module.update_model(self.optimizer_G, self.schedulers[0]) self.updated = True @@ -274,8 +284,13 @@ class SRGANModel(BaseModel): self.fea_GenOut = [] self.fake_H = [] var_ref_skips = [] - for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): - fea_GenOut, fake_GenOut = self.netG(var_L) + for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): + if random.random() > self.gan_lq_img_use_prob: + fea_GenOut, fake_GenOut = self.netG(var_L) + using_gan_img = False + else: + fea_GenOut, fake_GenOut = self.netG(var_LGAN) + using_gan_img = True if _profile: print("Gen forward %f" % (time() - _t,)) @@ -286,11 +301,14 @@ class SRGANModel(BaseModel): l_g_total = 0 if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - if self.cri_pix: # pixel loss + if using_gan_img: + l_g_pix_log = None + l_g_fea_log = None + if self.cri_pix and not using_gan_img: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix - if self.cri_fea: # feature loss + if self.cri_fea and not using_gan_img: # feature loss real_fea = self.netF(pix).detach() fake_fea = self.netF(fea_GenOut) fea_w = self.l_fea_sched.get_weight_for_step(step) @@ -348,10 +366,14 @@ class SRGANModel(BaseModel): self.optimizer_D.zero_grad() real_disc_images = [] fake_disc_images = [] - for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): + for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): + if random.random() > self.gan_lq_img_use_prob: + gen_input = var_L + else: + gen_input = var_LGAN # Re-compute generator outputs (post-update). with torch.no_grad(): - _, fake_H = self.netG(var_L) + _, fake_H = self.netG(gen_input) # The following line detaches all generator outputs that are not None. fake_H = fake_H.detach() @@ -510,9 +532,9 @@ class SRGANModel(BaseModel): # Log metrics if step % self.D_update_ratio == 0 and step >= self.D_init_iters: - if self.cri_pix: + if self.cri_pix and l_g_pix_log is not None: self.add_log_entry('l_g_pix', l_g_pix_log.item()) - if self.cri_fea: + if self.cri_fea and l_g_fea_log is not None: self.add_log_entry('feature_weight', fea_w) self.add_log_entry('l_g_fea', l_g_fea_log.item()) if self.l_gan_w > 0: