forked from mrq/DL-Art-School
Latent space playground
This commit is contained in:
parent a1d4c9f83c
commit f2422f1d75
@@ -128,6 +128,36 @@ def extract_center_latent(ref, lat):
     return ref[:, :, math.floor(dh):-math.ceil(dh), math.floor(dw):-math.ceil(dw)]
 
 
+def linear_interpolation(latents1, latents2, proportion):
+    return [l1*proportion+l2*(1-proportion) for l1, l2 in zip(latents1, latents2)]
+
+
+def slerp(latents1, latents2, proportion):
+    res = []
+    for low, high in zip(latents1, latents2):
+        low_norm = low / torch.norm(low, dim=[2,3], keepdim=True)
+        high_norm = high / torch.norm(high, dim=[2,3], keepdim=True)
+        omega = torch.acos((low_norm * high_norm).sum(1))
+        so = torch.sin(omega)
+        res.append((torch.sin((1.0 - proportion) * omega) / so).unsqueeze(1) * low + (torch.sin(proportion * omega) / so).unsqueeze(1) * high)
+    return res
+
+
+def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10):
+    for i in range(steps):
+        proportion = i / (steps-1)
+        lats = linear_interpolation(latents1, latents2, proportion)
+        hr, _ = gen(lr=lq,
+                    z=lats[0],
+                    reverse=True,
+                    epses=lats,
+                    add_gt_noise=False)
+        torchvision.transforms.ToPILImage()(hr.squeeze(0).cpu())
+        torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%i.png" % (i,)))
+    # Stopped using this because PIL's animated GIF output is total crap.
+    #images[0].save(output_file, save_all=True, append_images=images[1:], duration=80, loop=0)
+
+
 if __name__ == "__main__":
     #### options
     torch.backends.cudnn.benchmark = True
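Note on the helpers added above: linear_interpolation weights latents1 by proportion, so proportion=0 returns latents2 and proportion=1 returns latents1, and create_interpolation_video therefore sweeps from latents2 toward latents1 as i runs from 0 to steps-1. slerp is defined alongside it but is not referenced in the hunks shown. Below is a minimal standalone sketch (not part of the commit; the latent shapes are made-up stand-ins for SRFlow's per-scale epses) that exercises that blending convention:

# Minimal standalone sketch (not from the commit): check the blending convention
# of linear_interpolation on dummy per-scale latents. Shapes are made up.
import torch

def linear_interpolation(latents1, latents2, proportion):
    # Same one-liner as the helper added in this commit.
    return [l1 * proportion + l2 * (1 - proportion) for l1, l2 in zip(latents1, latents2)]

torch.manual_seed(0)
latents_a = [torch.randn(1, 48, 2 ** k, 2 ** k) for k in (3, 4, 5)]  # hypothetical latent pyramid
latents_b = [torch.zeros_like(l) for l in latents_a]                 # e.g. the zero latents used later

for t in (0.0, 0.5, 1.0):
    mixed = linear_interpolation(latents_a, latents_b, t)
    # t=0 reproduces latents_b (all zeros), t=1 reproduces latents_a.
    print(t, [round(m.abs().mean().item(), 3) for m in mixed])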
@@ -150,17 +180,18 @@ if __name__ == "__main__":
     gen = model.networks['generator']
     gen.eval()
 
-    mode = "feed_through" # restore | latent_transfer | feed_through
+    mode = "temperature" # temperature | restore | latent_transfer | feed_through
     #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
-    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*"
+    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
+    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
     scale = 2
-    resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
-    temperature = .3
+    resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
+    temperature = 1
     output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
 
     # Data types <- used to perform latent transfer.
     data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
-    data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
+    data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
     #data_type_filters = ["*lanette*"]
     max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
     max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
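Per the comments above, the data_type_filters are glob patterns applied under data_path, and max_ref_datatypes caps how many matches are kept per filter. A hypothetical sketch of that selection step, assuming it is done with plain glob (the helper name and exact sampling logic are guesses, not code from this repository):

# Hypothetical sketch of how data_type_filters could be expanded into reference
# images for latent transfer; not code from this repository.
import glob, os, random

def pick_reference_images(data_path, data_type_filters, max_ref_datatypes):
    refs = []
    for pattern in data_type_filters:
        matches = glob.glob(os.path.join(data_path, pattern))
        # Keep at most max_ref_datatypes images per data type.
        refs.append(random.sample(matches, min(len(matches), max_ref_datatypes)))
    return refs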
@@ -183,7 +214,6 @@ if __name__ == "__main__":
         cvimg = torch2cv(img)
         cvimg = corruptor.corrupt_images([cvimg])[0]
         img = cv2torch(cvimg)
-        torchvision.utils.save_image(img, "corrupted_lq_%i.png" % (random.randint(0, 100),))
         return img
     dt_latents = [fetch_latents_for_image(gen, i, scale, corrupt_and_downsample) for i in dt_transfers]
 
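torch2cv and cv2torch are project helpers that do not appear in this diff. A rough sketch of what such a round trip usually looks like, assuming CHW float tensors in [0,1] on the torch side and HWC uint8 arrays on the OpenCV side (channel order, dtype, and value range are all assumptions; the repository's helpers may differ):

# Rough sketch of a torch <-> OpenCV image round trip; the real torch2cv/cv2torch
# helpers in this repo may differ (channel order, dtype, value range are assumptions).
import numpy as np
import torch

def torch2cv(img: torch.Tensor) -> np.ndarray:
    # CHW float in [0,1] -> HWC uint8
    return (img.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

def cv2torch(img: np.ndarray) -> torch.Tensor:
    # HWC uint8 -> CHW float in [0,1]
    return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0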
@@ -214,6 +244,8 @@ if __name__ == "__main__":
                 dts.append([extract_center_latent(sl, l) * temperature for l, sl in zip(latents, slat)])
             latents = dts
             multiple_latents = True
+        elif mode == "temperature":
+            latents = [l * temperature for l in latents]
 
         # Re-compute each image with the new metrics
         if not multiple_latents:
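The new temperature mode multiplies every latent tensor by a single scalar. Assuming the latents behave like the roughly unit-variance Gaussian codes a normalizing flow is trained against, a temperature below 1 shrinks their standard deviation by the same factor and biases the decoded image toward the smoother mean, while temperature = 1 (the new default above) leaves them unchanged. A quick standalone check of that scaling on dummy tensors:

# Standalone check: scaling latents by a temperature scales their std by the same factor.
import torch

torch.manual_seed(0)
latents = [torch.randn(1, 48, 16, 16) for _ in range(3)]  # stand-ins for the script's latents

for temperature in (1.0, 0.5, 0.0):
    scaled = [l * temperature for l in latents]           # same operation as the new mode
    print(temperature, [round(s.std().item(), 3) for s in scaled])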
@@ -221,13 +253,8 @@ if __name__ == "__main__":
         else:
             lats = latents
         for j in range(len(lats)):
-            hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
-                        z=lats[j][0],
-                        reverse=True,
-                        epses=lats[j],
-                        add_gt_noise=False)
-            if torch.isnan(torch.max(hr)):
-                continue
-            os.makedirs(os.path.join(output_path), exist_ok=True)
-            torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
-            torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))
+            path = os.path.join(output_path, "%i_%i" % (im_it, j))
+            os.makedirs(path, exist_ok=True)
+            torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
+            create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
+                                       path, [torch.zeros_like(l) for l in lats[j]], lats[j])
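As called here, create_interpolation_video gets all-zero tensors as latents1 and the image's own latents as latents2, so the numbered frames sweep from the image's latents at frame 0 down to the zero latent at the last frame, written as 0.png, 1.png, … in the per-image directory (the earlier PIL GIF path stays commented out). One way to assemble those frames into an actual video afterwards, assuming ffmpeg is available on PATH (this step is not part of the script):

# Assemble the numbered frames written by create_interpolation_video into an mp4.
# Assumes ffmpeg is on PATH; not part of the commit.
import subprocess

def frames_to_video(frame_dir, out_file, fps=12):
    subprocess.run(["ffmpeg", "-y", "-framerate", str(fps),
                    "-i", f"{frame_dir}/%d.png",
                    "-c:v", "libx264", "-pix_fmt", "yuv420p", out_file],
                   check=True)

# e.g. frames_to_video(r"E:\4k6k\mmsr\results\latent_playground\0_0", "0_0.mp4")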