Modifications that allow the developer to explicitly specify a different image set for the PIX and feature losses

This commit is contained in:
James Betker 2020-04-22 10:11:14 -06:00
parent 12d92dc443
commit 79aff886b5
3 changed files with 25 additions and 26 deletions

View File

@ -19,12 +19,14 @@ class LQGTDataset(data.Dataset):
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.LQ_env, self.GT_env = None, None # environments for lmdb
self.paths_PIX, self.sizes_PIX = None, None
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdb
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
if 'dataroot_PIX' in opt:
if 'dataroot_PIX' in opt.keys():
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ) == len(
@ -39,7 +41,7 @@ class LQGTDataset(data.Dataset):
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
if 'dataroot_PIX' in self.opt:
if 'dataroot_PIX' in self.opt.keys():
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
meminit=False)
@ -61,12 +63,13 @@ class LQGTDataset(data.Dataset):
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get the pix image
if self.PIX_path is not None:
PIX_path = self.PIX_path[index]
if self.paths_PIX is not None:
PIX_path = self.paths_PIX[index]
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
if self.opt['color']: # change color space if necessary
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
else:
img_PIX = img_GT
# get LQ image
if self.paths_LQ:
@ -98,14 +101,8 @@ class LQGTDataset(data.Dataset):
img_LQ = np.expand_dims(img_LQ, axis=2)
if self.opt['phase'] == 'train':
# if the image size is too small
H, W, _ = img_GT.shape
if H < GT_size or W < GT_size:
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
assert H >= GT_size and W >= GT_size
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
@ -116,9 +113,10 @@ class LQGTDataset(data.Dataset):
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
# augmentation - flip, rotate
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
self.opt['use_rot'])
if self.opt['color']: # change color space if necessary
@ -129,12 +127,14 @@ class LQGTDataset(data.Dataset):
if img_GT.shape[2] == 3:
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_PIX = img_PIX[:, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
if LQ_path is None:
LQ_path = GT_path
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self):
return len(self.paths_GT)

View File

@ -140,10 +140,7 @@ class SRGANModel(BaseModel):
self.var_H = data['GT'].to(self.device) # GT
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
if 'PIX' in data:
self.pix = data['PIX']
else:
self.pix = self.var_H
self.pix = data['PIX'].to(self.device)
def optimize_parameters(self, step):
@ -151,6 +148,7 @@ class SRGANModel(BaseModel):
for i in range(self.var_L.shape[0]):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
# G
for p in self.netD.parameters():
@ -165,7 +163,7 @@ class SRGANModel(BaseModel):
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.var_H).detach()
real_fea = self.netF(self.pix).detach()
fake_fea = self.netF(self.fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea

View File

@ -1,5 +1,5 @@
#### general settings
name: ESRGANx4_blacked_ft
name: ESRGANx4_blacked_lqprn
use_tb_logger: true
model: srgan
distortion: sr
@ -13,7 +13,8 @@ datasets:
name: blacked
mode: LQGT
dataroot_GT: ../datasets/blacked/train/hr
dataroot_LQ: ../datasets/blacked/train/lr
dataroot_LQ: ../datasets/lqprn/train/lr
dataroot_PIX: ../datasets/lqprn/train/hr
use_shuffle: true
n_workers: 4 # per GPU
@ -42,10 +43,10 @@ network_D:
#### path
path:
pretrain_model_G: ../experiments/ESRGANx4_blacked_ft/models/31500_G.pth
pretrain_model_D: ../experiments/ESRGANx4_blacked_ft/models/31500_D.pth
pretrain_model_G: ../experiments/blacked_gen_20000_epochs.pth
pretrain_model_D: ../experiments/blacked_disc_20000_epochs.pth
resume_state: ~
strict_load: true
resume_state: ../experiments/ESRGANx4_blacked_ft/training_state/31500.state
#### training settings: learning rate scheme, loss
train:
@ -65,7 +66,7 @@ train:
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
pixel_weight: !!float 5e-3
feature_criterion: l1
feature_weight: 1
gan_type: ragan # gan | ragan