diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py index ed057436..681b353c 100644 --- a/codes/scripts/srflow_latent_space_playground.py +++ b/codes/scripts/srflow_latent_space_playground.py @@ -66,7 +66,7 @@ def fetch_latents_for_image(gen, img, scale, lr_infer=interpolate_lr): def fetch_latents_for_images(gen, imgs, scale, lr_infer=interpolate_lr): latents = [] - for img in tqdm(imgs): + for img in imgs: z, _, _ = gen(gt=img, lr=lr_infer(img, scale), epses=[], @@ -115,7 +115,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True srg_analyze = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow8x.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt @@ -132,85 +132,83 @@ if __name__ == "__main__": gen = model.networks['generator'] gen.eval() - mode = "latent_transfer" - imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" + mode = "restore" # restore | latent_transfer | feed_through + #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" + imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*" desired_size = None # (640,640) # <- Required when doing style transfer. - scale = 2 - resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. - temperature = .65 + scale = 8 + resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. + temperature = 1 output_path = "E:\\4k6k\\mmsr\\results\\latent_playground" # Data types <- used to perform latent transfer. data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half" - data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"] - #data_type_filters = ["*lanette*"] + #data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"] + data_type_filters = ["*lanette*"] max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from. interpolation_steps = 30 with torch.no_grad(): - # Fetch the images to resample. - resample_imgs = [] - img_files = glob(imgs_to_resample_pattern) - for i, img_file in enumerate(img_files): - if i > 5: - break - t = image_2_tensor(img_file, desired_size).to(model.env['device']) - if resample_factor != 1: - t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic") - resample_imgs.append(t) - - # Fetch the latent metrics & latents for each image we are resampling. - latents = fetch_latents_for_images(gen, resample_imgs, scale) - - multiple_latents = False - if mode == "restore": - for i, latent_set in enumerate(latents): - latents[i] = local_norm(spatial_norm(latent_set)) - latents[i] = [l * temperature for l in latents[i]] - elif mode == "feed_through": - latents = [torch.randn_like(l) * temperature for l in latents[i]] - elif mode == "latent_transfer": + # Compute latent variables for the reference images. + if mode == "latent_transfer": # Just get the **one** result for each pattern and use that latent. dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters] dt_transfers = [image_2_tensor(i, desired_size) for i in dt_imgs] # Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized) for j in range(len(dt_transfers)): if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600: - dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1/2, mode='area') - corruptor = ImageCorruptor({'fixed_corruptions':['jpeg-low', 'gaussian_blur_5']}) - + dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area') + corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']}) def corrupt_and_downsample(img, scale): - img = F.interpolate(img, scale_factor=1/scale, mode="area") + img = F.interpolate(img, scale_factor=1 / scale, mode="area") from data.util import torch2cv, cv2torch cvimg = torch2cv(img) cvimg = corruptor.corrupt_images([cvimg])[0] img = cv2torch(cvimg) - torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0,100),)) + torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0, 100),)) return img - dt_latents = [fetch_latents_for_image(gen, i, scale, corrupt_and_downsample) for i in dt_transfers] - tlatents = [] - for lat in latents: + + # Fetch the images to resample. + img_files = glob(imgs_to_resample_pattern) + random.shuffle(img_files) + for im_it, img_file in enumerate(tqdm(img_files)): + t = image_2_tensor(img_file, desired_size).to(model.env['device']) + if resample_factor != 1: + t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic") + resample_img = t + + # Fetch the latent metrics & latents for each image we are resampling. + latents = fetch_latents_for_images(gen, [resample_img], scale)[0] + + multiple_latents = False + if mode == "restore": + latents = local_norm(spatial_norm(latents)) + latents = [l * temperature for l in latents] + elif mode == "feed_through": + latents = [torch.randn_like(l) * temperature for l in latents] + elif mode == "latent_transfer": dts = [] for slat in dt_latents: - assert slat[0].shape[2] >= lat[0].shape[2] - assert slat[0].shape[3] >= lat[0].shape[3] - dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(lat, slat)]) - tlatents.append(dts) - latents = tlatents - multiple_latents = True + assert slat[0].shape[2] >= latents[0].shape[2] + assert slat[0].shape[3] >= latents[0].shape[3] + dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(latents, slat)]) + latents = dts + multiple_latents = True - # Re-compute each image with the new metrics - for i, img in enumerate(resample_imgs): + # Re-compute each image with the new metrics if not multiple_latents: - lats = [latents[i]] + lats = [latents] else: - lats = latents[i] + lats = latents + torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it))) for j in range(len(lats)): - hr, _ = gen(lr=F.interpolate(img, scale_factor=1/scale, mode="area"), + hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"), z=lats[j][0], reverse=True, epses=lats[j], add_gt_noise=False) + if torch.isnan(torch.max(hr)): + continue os.makedirs(os.path.join(output_path), exist_ok=True) - torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.png" % (i,j))) + torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))