Allow a separate dataset to be pushed in for GAN-only training

James Betker 2020-07-26 21:44:45 -06:00
parent b06e1784e1
commit 9a8f227501
3 changed files with 54 additions and 9 deletions
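For context, a minimal sketch of how the two new options introduced by this commit might be wired into a configuration. Only dataroot_GAN and gan_lowres_use_probability come from this change; the paths and surrounding keys are illustrative placeholders, and the project's actual option files carry many more keys.

# Illustrative option dicts only.
dataset_opt = {
    'dataroot_GT': '/data/hr',            # paired HR targets
    'dataroot_LQ': '/data/lr',            # paired LR inputs
    'dataroot_GAN': '/data/lr_unpaired',  # new: unpaired LR pool used only for the GAN loss
}
train_opt = {
    'gan_lowres_use_probability': 0.5,    # new: chance of feeding the unpaired LR batch to G for the GAN loss
}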

View File

@@ -27,6 +27,7 @@ class LQGTDataset(data.Dataset):
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.paths_PIX, self.sizes_PIX = None, None
+self.paths_GAN, self.sizes_GAN = None, None
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
@@ -45,6 +46,10 @@ class LQGTDataset(data.Dataset):
self.doCrop = opt['doCrop']
if 'dataroot_PIX' in opt.keys():
    self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
+# dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where
+# LR and HR do not need to be paired.
+if 'dataroot_GAN' in opt.keys():
+    self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN'])
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
@@ -127,6 +132,11 @@ class LQGTDataset(data.Dataset):
if img_LQ.ndim == 2:
    img_LQ = np.expand_dims(img_LQ, axis=2)
+img_GAN = None
+if self.paths_GAN:
+    GAN_path = self.paths_GAN[index % self.sizes_GAN]
+    img_GAN = util.read_img(self.LQ_env, GAN_path)
# Enforce force_resize constraints.
h, w, _ = img_LQ.shape
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
@@ -149,11 +159,15 @@ class LQGTDataset(data.Dataset):
    rnd_h = random.randint(0, max(0, H - LQ_size))
    rnd_w = random.randint(0, max(0, W - LQ_size))
    img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
+    if img_GAN is not None:
+        img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
    rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
    img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
    img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
else:
    img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
+    if img_GAN is not None:
+        img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
    img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
    img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
@@ -186,6 +200,8 @@ class LQGTDataset(data.Dataset):
if img_GT.shape[2] == 3:
    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
    img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
+    if img_GAN is not None:
+        img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB)
    img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
@@ -204,13 +220,18 @@ class LQGTDataset(data.Dataset):
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
img_LQ = F.to_tensor(img_LQ)
+if img_GAN is not None:
+    img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()
lq_noise = torch.randn_like(img_LQ) * 5 / 255
img_LQ += lq_noise
if LQ_path is None:
    LQ_path = GT_path
-return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
+d = {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
+if img_GAN is not None:
+    d['GAN'] = img_GAN
+return d

def __len__(self):
    return len(self.paths_GT)
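Because __getitem__ indexes the unpaired pool with index % self.sizes_GAN, the GAN dataset does not need to match the length of the paired GT/LQ set. A minimal sketch of that wrap-around behaviour (file names invented):

paths_GAN = ['a.png', 'b.png', 'c.png']          # 3 unpaired LR images
sizes_GAN = len(paths_GAN)
for index in range(7):                           # 7 paired GT/LQ items
    print(index, paths_GAN[index % sizes_GAN])   # cycles a, b, c, a, b, c, a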

View File

@@ -62,8 +62,10 @@ def get_image_paths(data_type, dataroot, weights=[]):
            for j in range(extends):
                paths.extend(_get_paths_from_images(r))
            paths = sorted(paths)
+            sizes = len(paths)
        else:
            paths = sorted(_get_paths_from_images(dataroot))
+            sizes = len(paths)
    else:
        raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
return paths, sizes
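The two added sizes = len(paths) lines matter because the dataset above now relies on the returned count for its modulo indexing, and the image-folder branches previously did not set it. A caller-side sketch of the resulting behaviour (the directory path is invented):

paths_GAN, sizes_GAN = get_image_paths('img', '/data/lr_unpaired')
assert sizes_GAN == len(paths_GAN)   # folder-backed datasets now report a usable element count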

View File

@@ -197,6 +197,9 @@ class SRGANModel(BaseModel):
self.swapout_D_duration = 0
self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
+# GAN LQ image params
+self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
self.print_network() # print network
self.load() # load G and D if needed
self.load_random_corruptor()
@@ -225,6 +228,13 @@ class SRGANModel(BaseModel):
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
+if 'GAN' in data.keys():
+    self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]
+else:
+    # If not provided, use provided LQ for anyplace where the GAN would have been used.
+    self.gan_img = self.var_L
+    self.gan_lq_img_use_prob = 0  # Safety valve for not goofing.
if not self.updated:
    self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
    self.updated = True
@@ -274,8 +284,13 @@ class SRGANModel(BaseModel):
self.fea_GenOut = []
self.fake_H = []
var_ref_skips = []
-for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
+for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
-    fea_GenOut, fake_GenOut = self.netG(var_L)
+    if random.random() > self.gan_lq_img_use_prob:
+        fea_GenOut, fake_GenOut = self.netG(var_L)
+        using_gan_img = False
+    else:
+        fea_GenOut, fake_GenOut = self.netG(var_LGAN)
+        using_gan_img = True
    if _profile:
        print("Gen forward %f" % (time() - _t,))
@@ -286,11 +301,14 @@ class SRGANModel(BaseModel):
l_g_total = 0
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
-    if self.cri_pix: # pixel loss
+    if using_gan_img:
+        l_g_pix_log = None
+        l_g_fea_log = None
+    if self.cri_pix and not using_gan_img: # pixel loss
        l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
        l_g_pix_log = l_g_pix / self.l_pix_w
        l_g_total += l_g_pix
-    if self.cri_fea: # feature loss
+    if self.cri_fea and not using_gan_img: # feature loss
        real_fea = self.netF(pix).detach()
        fake_fea = self.netF(fea_GenOut)
        fea_w = self.l_fea_sched.get_weight_for_step(step)
@@ -348,10 +366,14 @@ class SRGANModel(BaseModel):
self.optimizer_D.zero_grad()
real_disc_images = []
fake_disc_images = []
-for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
+for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
+    if random.random() > self.gan_lq_img_use_prob:
+        gen_input = var_L
+    else:
+        gen_input = var_LGAN
    # Re-compute generator outputs (post-update).
    with torch.no_grad():
-        _, fake_H = self.netG(var_L)
+        _, fake_H = self.netG(gen_input)
    # The following line detaches all generator outputs that are not None.
    fake_H = fake_H.detach()
@@ -510,9 +532,9 @@ class SRGANModel(BaseModel):
# Log metrics
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
-    if self.cri_pix:
+    if self.cri_pix and l_g_pix_log is not None:
        self.add_log_entry('l_g_pix', l_g_pix_log.item())
-    if self.cri_fea:
+    if self.cri_fea and l_g_fea_log is not None:
        self.add_log_entry('feature_weight', fea_w)
        self.add_log_entry('l_g_fea', l_g_fea_log.item())
    if self.l_gan_w > 0:
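Condensed, the sampling logic this commit adds to both the generator and discriminator passes amounts to the following; this is a sketch of the idea, not the exact training-loop code:

import random

def pick_generator_input(var_L, var_LGAN, gan_lq_img_use_prob):
    # With probability gan_lq_img_use_prob, feed the unpaired LR batch to G.
    if random.random() > gan_lq_img_use_prob:
        return var_L, False    # paired LR: pixel and feature losses still apply
    return var_LGAN, True      # unpaired LR: only the GAN/discriminator terms apply

# When no 'GAN' batch is supplied, feed_data forces gan_lq_img_use_prob to 0,
# so the paired LR is always chosen and behaviour is unchanged.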