Support independent PIX dataroot

This commit is contained in:
James Betker 2020-04-22 00:40:13 -06:00
parent 05aafef938
commit 4d269fdac6
2 changed files with 18 additions and 2 deletions

View File

@ -23,6 +23,8 @@ class LQGTDataset(data.Dataset):
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
if 'dataroot_PIX' in opt:
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ) == len(
@ -37,6 +39,9 @@ class LQGTDataset(data.Dataset):
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
if 'dataroot_PIX' in self.opt:
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
@ -55,6 +60,14 @@ class LQGTDataset(data.Dataset):
if self.opt['color']: # change color space if necessary
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get the pix image
if self.PIX_path is not None:
PIX_path = self.PIX_path[index]
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
if self.opt['color']: # change color space if necessary
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
# get LQ image
if self.paths_LQ:
LQ_path = self.paths_LQ[index]

View File

@ -140,6 +140,10 @@ class SRGANModel(BaseModel):
self.var_H = data['GT'].to(self.device) # GT
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
if 'PIX' in data:
self.pix = data['PIX']
else:
self.pix = self.var_H
def optimize_parameters(self, step):
@ -148,7 +152,6 @@ class SRGANModel(BaseModel):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
# G
for p in self.netD.parameters():
p.requires_grad = False
@ -159,7 +162,7 @@ class SRGANModel(BaseModel):
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.var_H).detach()