From e5a3e6b9b52a921f50636b4c6057915a5503e4e2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 14 Dec 2020 23:59:49 -0700 Subject: [PATCH] srflow latent space misc --- .../scripts/srflow_latent_space_playground.py | 33 ++++++++++--------- codes/train.py | 3 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py index 7dd31902..9cacd16b 100644 --- a/codes/scripts/srflow_latent_space_playground.py +++ b/codes/scripts/srflow_latent_space_playground.py @@ -143,7 +143,8 @@ def slerp(latents1, latents2, proportion): return res -def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10): +def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10, prefix=''): + # Outputs a series of images interpolated from [latents1] to [latents2]. image 0 biases towards latents2. for i in range(steps): proportion = i / (steps-1) lats = linear_interpolation(latents1, latents2, proportion) @@ -153,7 +154,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1 epses=lats, add_gt_noise=False) torchvision.transforms.ToPILImage()(hr.squeeze(0).cpu()) - torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%i.png" % (i,))) + torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%s_%i.png" % (prefix, i))) # Stopped using this because PILs animated gif output is total crap. #images[0].save(output_file, save_all=True, append_images=images[1:], duration=80, loop=0) @@ -163,7 +164,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True srg_analyze = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt @@ -185,15 +186,15 @@ if __name__ == "__main__": imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*" scale = 2 - resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. + resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling. temperature = 1 - output_path = "E:\\4k6k\\mmsr\\results\\latent_playground" + output_path = "..\\..\\results\\latent_playground" # Data types <- used to perform latent transfer. data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half" - data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"] - #data_type_filters = ["*lanette*"] - max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur. + #data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"] + data_type_filters = ["*lanette*"] + max_size = 1600 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur. max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from. interpolation_steps = 30 @@ -204,9 +205,9 @@ if __name__ == "__main__": dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters] dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs] # Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized) - for j in range(len(dt_transfers)): - if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600: - dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area') + #for j in range(len(dt_transfers)): + # if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600: + # dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area') corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']}) def corrupt_and_downsample(img, scale): img = F.interpolate(img, scale_factor=1 / scale, mode="area") @@ -219,11 +220,13 @@ if __name__ == "__main__": # Fetch the images to resample. img_files = glob(imgs_to_resample_pattern) - random.shuffle(img_files) + #random.shuffle(img_files) for im_it, img_file in enumerate(tqdm(img_files)): t = image_2_tensor(img_file).to(model.env['device']) - if resample_factor != 1: + if resample_factor > 1: t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic") + elif resample_factor < 1: + t = F.interpolate(t, scale_factor=resample_factor, mode="area") # Ensure the input image is a factor of 16. _, _, h, w = t.shape h = 16 * (h // 16) @@ -260,6 +263,6 @@ if __name__ == "__main__": for j in range(len(lats)): path = os.path.join(output_path, "%i_%i" % (im_it, j)) os.makedirs(path, exist_ok=True) - torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it))) + torchvision.utils.save_image(resample_img, os.path.join(path, "orig_%i.jpg" % (im_it))) create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"), - path, [torch.zeros_like(l) for l in lats[j]], lats[j]) + path, [torch.zeros_like(l) for l in lats[j]], lats[j], prefix=mode) diff --git a/codes/train.py b/codes/train.py index 669dd90e..2e0ee82f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -219,6 +219,7 @@ class Trainer: for b in range(len(val_data['GT_path'])): img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) + util.mkdir(img_dir) self.model.feed_data(val_data, self.current_step) @@ -292,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_psnr_4x.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()