forked from mrq/DL-Art-School
Add doResizeLoss to dataset
doResizeLoss has a 50% chance to resize the LQ image to 50% size, then back to original size. This is useful to training a generator to recover these lost pixel values while also being able to do repairs on higher resolution images during training.
This commit is contained in:
parent
786a4288d6
commit
296135ec18
|
@ -146,6 +146,12 @@ class LQGTDataset(data.Dataset):
|
||||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||||
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
if self.opt['doResizeLoss']:
|
||||||
|
r = random.randrange(0, 10)
|
||||||
|
if r > 5:
|
||||||
|
img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR)
|
||||||
|
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
# augmentation - flip, rotate
|
# augmentation - flip, rotate
|
||||||
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
|
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
|
||||||
self.opt['use_rot'])
|
self.opt['use_rot'])
|
||||||
|
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
@ -187,7 +187,7 @@ def main():
|
||||||
|
|
||||||
#### log
|
#### log
|
||||||
if current_step % opt['logger']['print_freq'] == 0:
|
if current_step % opt['logger']['print_freq'] == 0:
|
||||||
logs = model.get_current_log()
|
logs = model.get_current_log(current_step)
|
||||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||||
for v in model.get_current_learning_rate():
|
for v in model.get_current_learning_rate():
|
||||||
message += '{:.3e},'.format(v)
|
message += '{:.3e},'.format(v)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user