Update to srflow_latent_space_playground

Author: James Betker
Date: 2020-11-26 20:31:21 -07:00
parent fd356580c0
commit 5f5420ff4a


@@ -66,7 +66,7 @@ def fetch_latents_for_image(gen, img, scale, lr_infer=interpolate_lr):
 def fetch_latents_for_images(gen, imgs, scale, lr_infer=interpolate_lr):
     latents = []
-    for img in tqdm(imgs):
+    for img in imgs:
         z, _, _ = gen(gt=img,
                       lr=lr_infer(img, scale),
                       epses=[],
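With tqdm removed from this helper, progress reporting becomes the caller's responsibility; the reworked main loop later in this commit wraps its own iteration over image files in tqdm instead. A minimal sketch of that division of labor, with a dummy generator standing in for the SRFlow network (the simplified helper signature here is hypothetical):

import torch
from tqdm import tqdm

def fetch_latents_for_images(gen, imgs, scale):
    # Simplified stand-in for the repo helper: no progress bar inside it anymore.
    return [gen(img) for img in imgs]

gen = lambda img: torch.randn_like(img)  # dummy "generator" for illustration
img_files = [torch.rand(1, 3, 32, 32) for _ in range(4)]

for img in tqdm(img_files):  # progress now lives in the caller's loop
    latents = fetch_latents_for_images(gen, [img], scale=8)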
@@ -115,7 +115,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow8x.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
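The default options path now points into ../../options/ rather than an experiment directory, so running the script with no arguments loads the 8x SRFlow configuration. An explicit invocation would look like the line below (the script filename is assumed from the commit title):

python srflow_latent_space_playground.py -opt ../../options/train_exd_imgsetext_srflow8x.yml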
@@ -132,85 +132,83 @@ if __name__ == "__main__":
     gen = model.networks['generator']
     gen.eval()

-    mode = "latent_transfer"
-    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
+    mode = "restore" # restore | latent_transfer | feed_through
+    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
+    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*"
     desired_size = None # (640,640) # <- Required when doing style transfer.
-    scale = 2
-    resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
-    temperature = .65
+    scale = 8
+    resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
+    temperature = 1
     output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"

     # Data types <- used to perform latent transfer.
     data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
-    data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
-    #data_type_filters = ["*lanette*"]
+    #data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
+    data_type_filters = ["*lanette*"]
     max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
     interpolation_steps = 30

     with torch.no_grad():
-        # Fetch the images to resample.
-        resample_imgs = []
-        img_files = glob(imgs_to_resample_pattern)
-        for i, img_file in enumerate(img_files):
-            if i > 5:
-                break
-            t = image_2_tensor(img_file, desired_size).to(model.env['device'])
-            if resample_factor != 1:
-                t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
-            resample_imgs.append(t)
-
-        # Fetch the latent metrics & latents for each image we are resampling.
-        latents = fetch_latents_for_images(gen, resample_imgs, scale)
-
-        multiple_latents = False
-        if mode == "restore":
-            for i, latent_set in enumerate(latents):
-                latents[i] = local_norm(spatial_norm(latent_set))
-                latents[i] = [l * temperature for l in latents[i]]
-        elif mode == "feed_through":
-            latents = [torch.randn_like(l) * temperature for l in latents[i]]
-        elif mode == "latent_transfer":
+        # Compute latent variables for the reference images.
+        if mode == "latent_transfer":
             # Just get the **one** result for each pattern and use that latent.
             dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
             dt_transfers = [image_2_tensor(i, desired_size) for i in dt_imgs]

             # Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
             for j in range(len(dt_transfers)):
                 if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
-                    dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1/2, mode='area')
+                    dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area')

-            corruptor = ImageCorruptor({'fixed_corruptions':['jpeg-low', 'gaussian_blur_5']})
+            corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']})
             def corrupt_and_downsample(img, scale):
-                img = F.interpolate(img, scale_factor=1/scale, mode="area")
+                img = F.interpolate(img, scale_factor=1 / scale, mode="area")
                 from data.util import torch2cv, cv2torch
                 cvimg = torch2cv(img)
                 cvimg = corruptor.corrupt_images([cvimg])[0]
                 img = cv2torch(cvimg)
-                torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0,100),))
+                torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0, 100),))
                 return img
             dt_latents = [fetch_latents_for_image(gen, i, scale, corrupt_and_downsample) for i in dt_transfers]

-        tlatents = []
-        for lat in latents:
-            dts = []
-            for slat in dt_latents:
-                assert slat[0].shape[2] >= lat[0].shape[2]
-                assert slat[0].shape[3] >= lat[0].shape[3]
-                dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(lat, slat)])
-            tlatents.append(dts)
-        latents = tlatents
-        multiple_latents = True
-
-        # Re-compute each image with the new metrics
-        for i, img in enumerate(resample_imgs):
+        # Fetch the images to resample.
+        img_files = glob(imgs_to_resample_pattern)
+        random.shuffle(img_files)
+        for im_it, img_file in enumerate(tqdm(img_files)):
+            t = image_2_tensor(img_file, desired_size).to(model.env['device'])
+            if resample_factor != 1:
+                t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
+            resample_img = t
+
+            # Fetch the latent metrics & latents for each image we are resampling.
+            latents = fetch_latents_for_images(gen, [resample_img], scale)[0]
+
+            multiple_latents = False
+            if mode == "restore":
+                latents = local_norm(spatial_norm(latents))
+                latents = [l * temperature for l in latents]
+            elif mode == "feed_through":
+                latents = [torch.randn_like(l) * temperature for l in latents]
+            elif mode == "latent_transfer":
+                dts = []
+                for slat in dt_latents:
+                    assert slat[0].shape[2] >= latents[0].shape[2]
+                    assert slat[0].shape[3] >= latents[0].shape[3]
+                    dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(latents, slat)])
+                latents = dts
+                multiple_latents = True
+
+            # Re-compute each image with the new metrics
             if not multiple_latents:
-                lats = [latents[i]]
+                lats = [latents]
             else:
-                lats = latents[i]
+                lats = latents
+            torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
             for j in range(len(lats)):
-                hr, _ = gen(lr=F.interpolate(img, scale_factor=1/scale, mode="area"),
+                hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
                             z=lats[j][0],
                             reverse=True,
                             epses=lats[j],
                             add_gt_noise=False)
+                if torch.isnan(torch.max(hr)):
+                    continue
                 os.makedirs(os.path.join(output_path), exist_ok=True)
-                torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.png" % (i,j)))
+                torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))
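Taken together, the reworked per-image loop derives latents in one of three ways: "restore" renormalizes the image's own latents (via the repo's spatial_norm/local_norm helpers) and scales them by the temperature, "feed_through" discards them in favor of fresh Gaussian noise, and "latent_transfer" crops spatially larger reference latents down to the target's size. A toy, self-contained sketch of the feed_through and latent_transfer arithmetic on dummy tensors (the shapes and two-level latent pyramid are hypothetical, not the network's real dimensions):

import torch

temperature = 1.0
# Hypothetical two-level latent pyramid for the image being resampled.
latents = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 4, 4)]

# feed_through: ignore the image's latents; sample fresh noise at temperature.
feed_through = [torch.randn_like(l) * temperature for l in latents]

# latent_transfer: crop each (larger) reference latent to the target's spatial
# size, then scale by temperature, mirroring the asserts and slicing in the diff.
ref_lats = [torch.randn(1, 4, 12, 12), torch.randn(1, 4, 6, 6)]
for r, l in zip(ref_lats, latents):
    assert r.shape[2] >= l.shape[2] and r.shape[3] >= l.shape[3]
transferred = [r[:, :, :l.shape[2], :l.shape[3]] * temperature
               for l, r in zip(latents, ref_lats)]
assert all(t.shape == l.shape for t, l in zip(transferred, latents))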