Allow separate dataset to pushed in for GAN-only training
This commit is contained in:
parent
b06e1784e1
commit
9a8f227501
|
@ -27,6 +27,7 @@ class LQGTDataset(data.Dataset):
|
|||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.sizes_LQ, self.sizes_GT = None, None
|
||||
self.paths_PIX, self.sizes_PIX = None, None
|
||||
self.paths_GAN, self.sizes_GAN = None, None
|
||||
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
||||
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
|
||||
|
||||
|
@ -45,6 +46,10 @@ class LQGTDataset(data.Dataset):
|
|||
self.doCrop = opt['doCrop']
|
||||
if 'dataroot_PIX' in opt.keys():
|
||||
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
|
||||
# dataroot_GAN is an alternative source of LR images specifically for use in computing the GAN loss, where
|
||||
# LR and HR do not need to be paired.
|
||||
if 'dataroot_GAN' in opt.keys():
|
||||
self.paths_GAN, self.sizes_GAN = util.get_image_paths(self.data_type, opt['dataroot_GAN'])
|
||||
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
if self.paths_LQ and self.paths_GT:
|
||||
|
@ -127,6 +132,11 @@ class LQGTDataset(data.Dataset):
|
|||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
|
||||
img_GAN = None
|
||||
if self.paths_GAN:
|
||||
GAN_path = self.paths_GAN[index % self.sizes_GAN]
|
||||
img_GAN = util.read_img(self.LQ_env, GAN_path)
|
||||
|
||||
# Enforce force_resize constraints.
|
||||
h, w, _ = img_LQ.shape
|
||||
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
|
||||
|
@ -149,11 +159,15 @@ class LQGTDataset(data.Dataset):
|
|||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
if img_GAN is not None:
|
||||
img_GAN = img_GAN[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
else:
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GAN is not None:
|
||||
img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
@ -186,6 +200,8 @@ class LQGTDataset(data.Dataset):
|
|||
if img_GT.shape[2] == 3:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
|
||||
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
|
||||
if img_GAN is not None:
|
||||
img_GAN = cv2.cvtColor(img_GAN, cv2.COLOR_BGR2RGB)
|
||||
img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
|
@ -204,13 +220,18 @@ class LQGTDataset(data.Dataset):
|
|||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
if img_GAN is not None:
|
||||
img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()
|
||||
|
||||
lq_noise = torch.randn_like(img_LQ) * 5 / 255
|
||||
img_LQ += lq_noise
|
||||
|
||||
if LQ_path is None:
|
||||
LQ_path = GT_path
|
||||
return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
d = {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
if img_GAN is not None:
|
||||
d['GAN'] = img_GAN
|
||||
return d
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths_GT)
|
||||
|
|
|
@ -62,8 +62,10 @@ def get_image_paths(data_type, dataroot, weights=[]):
|
|||
for j in range(extends):
|
||||
paths.extend(_get_paths_from_images(r))
|
||||
paths = sorted(paths)
|
||||
sizes = len(paths)
|
||||
else:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
sizes = len(paths)
|
||||
else:
|
||||
raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
|
||||
return paths, sizes
|
||||
|
|
|
@ -197,6 +197,9 @@ class SRGANModel(BaseModel):
|
|||
self.swapout_D_duration = 0
|
||||
self.swapout_duration = train_opt['swapout_duration'] if train_opt['swapout_duration'] else 0
|
||||
|
||||
# GAN LQ image params
|
||||
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
|
||||
|
||||
self.print_network() # print network
|
||||
self.load() # load G and D if needed
|
||||
self.load_random_corruptor()
|
||||
|
@ -225,6 +228,13 @@ class SRGANModel(BaseModel):
|
|||
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
if 'GAN' in data.keys():
|
||||
self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]
|
||||
else:
|
||||
# If not provided, use provided LQ for anyplace where the GAN would have been used.
|
||||
self.gan_img = self.var_L
|
||||
self.gan_lq_img_use_prob = 0 # Safety valve for not goofing.
|
||||
|
||||
if not self.updated:
|
||||
self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
|
||||
self.updated = True
|
||||
|
@ -274,8 +284,13 @@ class SRGANModel(BaseModel):
|
|||
self.fea_GenOut = []
|
||||
self.fake_H = []
|
||||
var_ref_skips = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_L)
|
||||
using_gan_img = False
|
||||
else:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_LGAN)
|
||||
using_gan_img = True
|
||||
|
||||
if _profile:
|
||||
print("Gen forward %f" % (time() - _t,))
|
||||
|
@ -286,11 +301,14 @@ class SRGANModel(BaseModel):
|
|||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
if using_gan_img:
|
||||
l_g_pix_log = None
|
||||
l_g_fea_log = None
|
||||
if self.cri_pix and not using_gan_img: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
|
||||
l_g_pix_log = l_g_pix / self.l_pix_w
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
if self.cri_fea and not using_gan_img: # feature loss
|
||||
real_fea = self.netF(pix).detach()
|
||||
fake_fea = self.netF(fea_GenOut)
|
||||
fea_w = self.l_fea_sched.get_weight_for_step(step)
|
||||
|
@ -348,10 +366,14 @@ class SRGANModel(BaseModel):
|
|||
self.optimizer_D.zero_grad()
|
||||
real_disc_images = []
|
||||
fake_disc_images = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
gen_input = var_L
|
||||
else:
|
||||
gen_input = var_LGAN
|
||||
# Re-compute generator outputs (post-update).
|
||||
with torch.no_grad():
|
||||
_, fake_H = self.netG(var_L)
|
||||
_, fake_H = self.netG(gen_input)
|
||||
# The following line detaches all generator outputs that are not None.
|
||||
fake_H = fake_H.detach()
|
||||
|
||||
|
@ -510,9 +532,9 @@ class SRGANModel(BaseModel):
|
|||
|
||||
# Log metrics
|
||||
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
|
||||
if self.cri_pix:
|
||||
if self.cri_pix and l_g_pix_log is not None:
|
||||
self.add_log_entry('l_g_pix', l_g_pix_log.item())
|
||||
if self.cri_fea:
|
||||
if self.cri_fea and l_g_fea_log is not None:
|
||||
self.add_log_entry('feature_weight', fea_w)
|
||||
self.add_log_entry('l_g_fea', l_g_fea_log.item())
|
||||
if self.l_gan_w > 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user