From 11d2b70bdd74b78b5c0d194417ad89df467f37f7 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 27 Nov 2020 12:03:16 -0700
Subject: [PATCH] Latent space playground work

---
 .../scripts/srflow_latent_space_playground.py | 89 +++++++++++--------
 1 file changed, 54 insertions(+), 35 deletions(-)

diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py
index 681b353c..dac3a929 100644
--- a/codes/scripts/srflow_latent_space_playground.py
+++ b/codes/scripts/srflow_latent_space_playground.py
@@ -20,14 +20,16 @@ from models.ExtensibleTrainer import ExtensibleTrainer
 from utils import util
 
 
-def image_2_tensor(impath, desired_size):
+def image_2_tensor(impath, max_size=None):
     img = Image.open(impath)
-    if desired_size is not None:
-        factor = max(desired_size[0] / img.width, desired_size[1] / img.height)
+    if max_size is not None:
+        factor = min(max_size / img.width, max_size / img.height)
         new_size = (int(math.ceil(img.width * factor)), int(math.ceil(img.height * factor)))
-        img = img.resize(new_size, Image.BICUBIC)
+        img = img.resize(new_size, Image.LANCZOS)
 
+    '''
+    # Useful for setting an image to an exact size.
 
     h_gap = img.height - desired_size[1]
     w_gap = img.width - desired_size[0]
     assert h_gap >= 0 and w_gap >= 0
@@ -35,18 +37,19 @@ def image_2_tensor(impath, desired_size):
     hb = desired_size[1] + ht
     wl = w_gap // 2
     wr = desired_size[1] + wl
+    '''
 
     timg = torchvision.transforms.ToTensor()(img).unsqueeze(0)
 
-    if desired_size is not None:
-        timg = timg[:, :3, ht:hb, wl:wr]
-        assert timg.shape[2] == desired_size[1] and timg.shape[3] == desired_size[0]
-    else:
-        # Enforce that the input must have a input dimension that is a factor of 16.
-        b, c, h, w = timg.shape
-        h = (h // 16) * 16
-        w = (w // 16) * 16
-        timg = timg[:, :3, :h, :w]
+    #if desired_size is not None:
+    #    timg = timg[:, :3, ht:hb, wl:wr]
+    #    assert timg.shape[2] == desired_size[1] and timg.shape[3] == desired_size[0]
+    #else:
+    # Enforce that the input must have an input dimension that is a multiple of 16.
+    b, c, h, w = timg.shape
+    h = (h // 16) * 16
+    w = (w // 16) * 16
+    timg = timg[:, :3, :h, :w]
 
     return timg
 
@@ -88,34 +91,49 @@ def fetch_spatial_metrics_for_latents(latents):
     return dt_scales, dt_biases
 
 
-def spatial_norm(latents):
+def spatial_norm(latents, exclusion_list=[]):
     nlatents = []
     for i in range(len(latents)):
         latent = latents[i]
-        b, c, h, w = latent.shape
-        s = latent.std(dim=[2, 3]).view(1,c,1,1)
-        b = latent.mean(dim=[2, 3]).view(1,c,1,1)
-        nlatents.append((latents[i] - b) / s)
+        if i in exclusion_list:
+            nlatents.append(latent)
+        else:
+            b, c, h, w = latent.shape
+            s = latent.std(dim=[2, 3]).view(1,c,1,1)
+            b = latent.mean(dim=[2, 3]).view(1,c,1,1)
+            nlatents.append((latents[i] - b) / s)
     return nlatents
 
 
-def local_norm(latents):
+def local_norm(latents, exclusion_list=[]):
     nlatents = []
     for i in range(len(latents)):
         latent = latents[i]
-        b, c, h, w = latent.shape
-        s = latent.std(dim=[1]).view(1,1,h,w)
-        b = latent.mean(dim=[1]).view(1,1,h,w)
-        nlatents.append((latents[i] - b) / s)
+        if i in exclusion_list:
+            nlatents.append(latent)
+        else:
+            b, c, h, w = latent.shape
+            s = latent.std(dim=[1]).view(1,1,h,w)
+            b = latent.mean(dim=[1]).view(1,1,h,w)
+            nlatents.append((latents[i] - b) / s)
     return nlatents
 
 
+# Extracts a rectangle of the same shape as <lat> from <ref> and returns it. This is taken from the center of <ref>.
+def extract_center_latent(ref, lat):
+    _, _, h, w = lat.shape
+    _, _, rh, rw = ref.shape
+    dw = (rw - w) / 2
+    dh = (rh - h) / 2
+    return ref[:, :, math.floor(dh):rh - math.ceil(dh), math.floor(dw):rw - math.ceil(dw)]
+
+
 if __name__ == "__main__":
     #### options
     torch.backends.cudnn.benchmark = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow8x.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -132,19 +150,19 @@ if __name__ == "__main__":
     gen = model.networks['generator']
     gen.eval()
 
-    mode = "restore" # restore | latent_transfer | feed_through
+    mode = "feed_through" # restore | latent_transfer | feed_through
     #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
     imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*"
-    desired_size = None # (640,640) # <- Required when doing style transfer.
-    scale = 8
-    resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
-    temperature = 1
+    scale = 2
+    resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
+    temperature = .3
     output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
 
     # Data types <- used to perform latent transfer.
     data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
-    #data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
-    data_type_filters = ["*lanette*"]
+    data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
+    #data_type_filters = ["*lanette*"]
+    max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
 
     max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
     interpolation_steps = 30
@@ -153,7 +171,7 @@ if __name__ == "__main__":
     if mode == "latent_transfer":
         # Just get the **one** result for each pattern and use that latent.
         dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
-        dt_transfers = [image_2_tensor(i, desired_size) for i in dt_imgs]
+        dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs]
         # Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
         for j in range(len(dt_transfers)):
             if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
@@ -173,7 +191,7 @@ if __name__ == "__main__":
     img_files = glob(imgs_to_resample_pattern)
     random.shuffle(img_files)
     for im_it, img_file in enumerate(tqdm(img_files)):
-        t = image_2_tensor(img_file, desired_size).to(model.env['device'])
+        t = image_2_tensor(img_file).to(model.env['device'])
         if resample_factor != 1:
             t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
         resample_img = t
@@ -184,6 +202,7 @@ if __name__ == "__main__":
         multiple_latents = False
         if mode == "restore":
             latents = local_norm(spatial_norm(latents))
+            #latents = spatial_norm(latents)
             latents = [l * temperature for l in latents]
         elif mode == "feed_through":
             latents = [torch.randn_like(l) * temperature for l in latents]
@@ -192,7 +211,7 @@ if __name__ == "__main__":
             for slat in dt_latents:
                 assert slat[0].shape[2] >= latents[0].shape[2]
                 assert slat[0].shape[3] >= latents[0].shape[3]
-                dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(latents, slat)])
+                dts.append([extract_center_latent(sl, l) * temperature for l, sl in zip(latents, slat)])
             latents = dts
             multiple_latents = True
 
@@ -201,7 +220,6 @@ if __name__ == "__main__":
             lats = [latents]
         else:
             lats = latents
-        torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
         for j in range(len(lats)):
             hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
                         z=lats[j][0],
@@ -211,4 +229,5 @@ if __name__ == "__main__":
             if torch.isnan(torch.max(hr)):
                 continue
             os.makedirs(os.path.join(output_path), exist_ok=True)
+            torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
             torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))
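
The helpers this patch adds are small enough to sanity-check in isolation. Below is a minimal sketch, not part of the patch: it lightly restates spatial_norm and extract_center_latent, with invented tensor shapes (the two-scale `latents` and `refs` lists are hypothetical stand-ins for SRFlow's multi-scale latent list) to show the per-channel normalization, the exclusion list, and the center-crop transfer:

import math

import torch


def spatial_norm(latents, exclusion_list=[]):
    # Normalize each latent to zero mean / unit std per channel, computed
    # over its spatial dimensions; indices in exclusion_list pass through.
    nlatents = []
    for i, latent in enumerate(latents):
        if i in exclusion_list:
            nlatents.append(latent)
        else:
            _, c, _, _ = latent.shape
            s = latent.std(dim=[2, 3]).view(1, c, 1, 1)
            m = latent.mean(dim=[2, 3]).view(1, c, 1, 1)
            nlatents.append((latent - m) / s)
    return nlatents


def extract_center_latent(ref, lat):
    # Center-crop ref down to the spatial shape of lat. Using rh-/rw-based
    # end indices keeps the slice valid when the margins are zero.
    _, _, h, w = lat.shape
    _, _, rh, rw = ref.shape
    dw = (rw - w) / 2
    dh = (rh - h) / 2
    return ref[:, :, math.floor(dh):rh - math.ceil(dh), math.floor(dw):rw - math.ceil(dw)]


# Hypothetical latents at two flow scales, plus larger reference latents to
# transfer from (as in the script's "latent_transfer" mode).
latents = [torch.randn(1, 16, 32, 32), torch.randn(1, 32, 16, 16)]
refs = [torch.randn(1, 16, 48, 48), torch.randn(1, 32, 16, 16)]

normed = spatial_norm(latents, exclusion_list=[1])  # latents[1] is untouched
transferred = [extract_center_latent(r, l) for r, l in zip(refs, latents)]
assert all(t.shape == l.shape for t, l in zip(transferred, latents))

One design note: the exclusion_list=[] mutable default mirrors the patch; a None default resolved inside the function is the usual Python idiom if the argument might ever be mutated.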