From 296135ec18740d52732ad9b438fddd1996d78700 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 8 Jun 2020 11:27:06 -0600 Subject: [PATCH] Add doResizeLoss to dataset doResizeLoss has a 50% chance to resize the LQ image to 50% size, then back to original size. This is useful for training a generator to recover these lost pixel values while also being able to do repairs on higher resolution images during training. --- codes/data/LQGT_dataset.py | 6 ++++++ codes/train.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index aae2ff89..83bc6a54 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -146,6 +146,12 @@ class LQGTDataset(data.Dataset): img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + if self.opt['doResizeLoss']: + r = random.randrange(0, 10) + if r > 5: + img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR) + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + # augmentation - flip, rotate img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'], self.opt['use_rot']) diff --git a/codes/train.py b/codes/train.py index 99b2d768..67ad9dac 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -187,7 +187,7 @@ def main(): 
#### log if current_step % opt['logger']['print_freq'] == 0: - logs = model.get_current_log() + logs = model.get_current_log(current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v)