forked from mrq/DL-Art-School
Support independent PIX dataroot
This commit is contained in:
parent
05aafef938
commit
4d269fdac6
|
@ -23,6 +23,8 @@ class LQGTDataset(data.Dataset):
|
|||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
if 'dataroot_PIX' in opt:
|
||||
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
if self.paths_LQ and self.paths_GT:
|
||||
assert len(self.paths_LQ) == len(
|
||||
|
@ -37,6 +39,9 @@ class LQGTDataset(data.Dataset):
|
|||
meminit=False)
|
||||
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
if 'dataroot_PIX' in self.opt:
|
||||
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||
|
@ -55,6 +60,14 @@ class LQGTDataset(data.Dataset):
|
|||
if self.opt['color']: # change color space if necessary
|
||||
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||
|
||||
# get the pix image
|
||||
if self.PIX_path is not None:
|
||||
PIX_path = self.PIX_path[index]
|
||||
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
|
||||
|
||||
|
||||
# get LQ image
|
||||
if self.paths_LQ:
|
||||
LQ_path = self.paths_LQ[index]
|
||||
|
|
|
@ -140,6 +140,10 @@ class SRGANModel(BaseModel):
|
|||
self.var_H = data['GT'].to(self.device) # GT
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
self.var_ref = input_ref.to(self.device)
|
||||
if 'PIX' in data:
|
||||
self.pix = data['PIX']
|
||||
else:
|
||||
self.pix = self.var_H
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
|
||||
|
@ -148,7 +152,6 @@ class SRGANModel(BaseModel):
|
|||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
|
||||
# G
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
@ -159,7 +162,7 @@ class SRGANModel(BaseModel):
|
|||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(self.var_H).detach()
|
||||
|
|
Loading…
Reference in New Issue
Block a user