From f2422f1d754d838990719243dd6e1c331080ba8f Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 29 Nov 2020 09:33:29 -0700 Subject: [PATCH] Latent space playground --- .../scripts/srflow_latent_space_playground.py | 59 ++++++++++++++----- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py index dac3a929..3fe89577 100644 --- a/codes/scripts/srflow_latent_space_playground.py +++ b/codes/scripts/srflow_latent_space_playground.py @@ -128,6 +128,36 @@ def extract_center_latent(ref, lat): return ref[:, :, math.floor(dh):-math.ceil(dh), math.floor(dw):-math.ceil(dw)] +def linear_interpolation(latents1, latents2, proportion): + return [l1*proportion+l2*(1-proportion) for l1, l2 in zip(latents1, latents2)] + + +def slerp(latents1, latents2, proportion): + res = [] + for low, high in zip(latents1, latents2): + low_norm = low / torch.norm(low, dim=[2,3], keepdim=True) + high_norm = high / torch.norm(high, dim=[2,3], keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res.append((torch.sin((1.0 - proportion) * omega) / so).unsqueeze(1) * low + (torch.sin(proportion * omega) / so).unsqueeze(1) * high) + return res + + +def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10): + for i in range(steps): + proportion = i / (steps-1) + lats = linear_interpolation(latents1, latents2, proportion) + hr, _ = gen(lr=lq, + z=lats[0], + reverse=True, + epses=lats, + add_gt_noise=False) + torchvision.transforms.ToPILImage()(hr.squeeze(0).cpu()) + torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%i.png" % (i,))) + # Stopped using this because PILs animated gif output is total crap. + #images[0].save(output_file, save_all=True, append_images=images[1:], duration=80, loop=0) + + if __name__ == "__main__": #### options torch.backends.cudnn.benchmark = True @@ -150,17 +180,18 @@ if __name__ == "__main__": gen = model.networks['generator'] gen.eval() - mode = "feed_through" # restore | latent_transfer | feed_through + mode = "temperature" # temperature | restore | latent_transfer | feed_through #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" - imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*" + #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" + imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*" scale = 2 - resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. - temperature = .3 + resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. + temperature = 1 output_path = "E:\\4k6k\\mmsr\\results\\latent_playground" # Data types <- used to perform latent transfer. data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half" - data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"] + data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"] #data_type_filters = ["*lanette*"] max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur. max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from. @@ -183,7 +214,6 @@ if __name__ == "__main__": cvimg = torch2cv(img) cvimg = corruptor.corrupt_images([cvimg])[0] img = cv2torch(cvimg) - torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0, 100),)) return img dt_latents = [fetch_latents_for_image(gen, i, scale, corrupt_and_downsample) for i in dt_transfers] @@ -214,6 +244,8 @@ if __name__ == "__main__": dts.append([extract_center_latent(sl, l) * temperature for l, sl in zip(latents, slat)]) latents = dts multiple_latents = True + elif mode == "temperature": + latents = [l * temperature for l in latents] # Re-compute each image with the new metrics if not multiple_latents: @@ -221,13 +253,8 @@ if __name__ == "__main__": else: lats = latents for j in range(len(lats)): - hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"), - z=lats[j][0], - reverse=True, - epses=lats[j], - add_gt_noise=False) - if torch.isnan(torch.max(hr)): - continue - os.makedirs(os.path.join(output_path), exist_ok=True) - torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it))) - torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j))) + path = os.path.join(output_path, "%i_%i" % (im_it, j)) + os.makedirs(path, exist_ok=True) + torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it))) + create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"), + path, [torch.zeros_like(l) for l in lats[j]], lats[j])