diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py index 6a58a4e2..621f671b 100644 --- a/codes/scripts/srflow_latent_space_playground.py +++ b/codes/scripts/srflow_latent_space_playground.py @@ -146,7 +146,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1 for i in range(steps): proportion = i / (steps-1) lats = linear_interpolation(latents1, latents2, proportion) - hr, _ = gen(lr=lq, + hr, _, _ = gen(lr=lq, z=lats[0], reverse=True, epses=lats, @@ -179,9 +179,9 @@ if __name__ == "__main__": gen = model.networks['generator'] gen.eval() - mode = "restore" # temperature | restore | latent_transfer | feed_through - #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" - imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" + mode = "feed_through" # temperature | restore | latent_transfer | feed_through + imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" + #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*" scale = 2 resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling. diff --git a/codes/test.py b/codes/test.py index 495e05ff..073f8fc8 100644 --- a/codes/test.py +++ b/codes/test.py @@ -48,7 +48,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_faces_glean.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt diff --git a/codes/train.py b/codes/train.py index ed925a88..1871f117 100644 --- a/codes/train.py +++ b/codes/train.py @@ -216,8 +216,8 @@ class Trainer: val_tqdm = tqdm(self.val_loader) for val_data in val_tqdm: idx += 1 - for b in range(len(val_data['GT_path'])): - img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0] + for b in range(len(val_data['HQ_path'])): + img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir)