diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index aae2ff89..83bc6a54 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -146,6 +146,12 @@ class LQGTDataset(data.Dataset): img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + if self.opt['doResizeLoss']: + r = random.randrange(0, 10) + if r > 5: + img_LQ = cv2.resize(img_LQ, (int(LQ_size/2), int(LQ_size/2)), interpolation=cv2.INTER_LINEAR) + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + # augmentation - flip, rotate img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'], self.opt['use_rot']) diff --git a/codes/train.py b/codes/train.py index 99b2d768..67ad9dac 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -187,7 +187,7 @@ def main(): #### log if current_step % opt['logger']['print_freq'] == 0: - logs = model.get_current_log() + logs = model.get_current_log(current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v)