Latent space playground work
This commit is contained in:
parent
4ab49b0d69
commit
11d2b70bdd
|
@ -20,14 +20,16 @@ from models.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from utils import util
|
from utils import util
|
||||||
|
|
||||||
|
|
||||||
def image_2_tensor(impath, desired_size):
|
def image_2_tensor(impath, max_size=None):
|
||||||
img = Image.open(impath)
|
img = Image.open(impath)
|
||||||
|
|
||||||
if desired_size is not None:
|
if max_size is not None:
|
||||||
factor = max(desired_size[0] / img.width, desired_size[1] / img.height)
|
factor = min(max_size / img.width, max_size / img.height)
|
||||||
new_size = (int(math.ceil(img.width * factor)), int(math.ceil(img.height * factor)))
|
new_size = (int(math.ceil(img.width * factor)), int(math.ceil(img.height * factor)))
|
||||||
img = img.resize(new_size, Image.BICUBIC)
|
img = img.resize(new_size, Image.LANCZOS)
|
||||||
|
|
||||||
|
'''
|
||||||
|
# Useful for setting an image to an exact size.
|
||||||
h_gap = img.height - desired_size[1]
|
h_gap = img.height - desired_size[1]
|
||||||
w_gap = img.width - desired_size[0]
|
w_gap = img.width - desired_size[0]
|
||||||
assert h_gap >= 0 and w_gap >= 0
|
assert h_gap >= 0 and w_gap >= 0
|
||||||
|
@ -35,18 +37,19 @@ def image_2_tensor(impath, desired_size):
|
||||||
hb = desired_size[1] + ht
|
hb = desired_size[1] + ht
|
||||||
wl = w_gap // 2
|
wl = w_gap // 2
|
||||||
wr = desired_size[1] + wl
|
wr = desired_size[1] + wl
|
||||||
|
'''
|
||||||
|
|
||||||
timg = torchvision.transforms.ToTensor()(img).unsqueeze(0)
|
timg = torchvision.transforms.ToTensor()(img).unsqueeze(0)
|
||||||
|
|
||||||
if desired_size is not None:
|
#if desired_size is not None:
|
||||||
timg = timg[:, :3, ht:hb, wl:wr]
|
# timg = timg[:, :3, ht:hb, wl:wr]
|
||||||
assert timg.shape[2] == desired_size[1] and timg.shape[3] == desired_size[0]
|
# assert timg.shape[2] == desired_size[1] and timg.shape[3] == desired_size[0]
|
||||||
else:
|
#else:
|
||||||
# Enforce that the input must have a input dimension that is a factor of 16.
|
# Enforce that the input must have a input dimension that is a factor of 16.
|
||||||
b, c, h, w = timg.shape
|
b, c, h, w = timg.shape
|
||||||
h = (h // 16) * 16
|
h = (h // 16) * 16
|
||||||
w = (w // 16) * 16
|
w = (w // 16) * 16
|
||||||
timg = timg[:, :3, :h, :w]
|
timg = timg[:, :3, :h, :w]
|
||||||
|
|
||||||
return timg
|
return timg
|
||||||
|
|
||||||
|
@ -88,34 +91,49 @@ def fetch_spatial_metrics_for_latents(latents):
|
||||||
return dt_scales, dt_biases
|
return dt_scales, dt_biases
|
||||||
|
|
||||||
|
|
||||||
def spatial_norm(latents):
|
def spatial_norm(latents, exclusion_list=[]):
|
||||||
nlatents = []
|
nlatents = []
|
||||||
for i in range(len(latents)):
|
for i in range(len(latents)):
|
||||||
latent = latents[i]
|
latent = latents[i]
|
||||||
b, c, h, w = latent.shape
|
if i in exclusion_list:
|
||||||
s = latent.std(dim=[2, 3]).view(1,c,1,1)
|
nlatents.append(latent)
|
||||||
b = latent.mean(dim=[2, 3]).view(1,c,1,1)
|
else:
|
||||||
nlatents.append((latents[i] - b) / s)
|
b, c, h, w = latent.shape
|
||||||
|
s = latent.std(dim=[2, 3]).view(1,c,1,1)
|
||||||
|
b = latent.mean(dim=[2, 3]).view(1,c,1,1)
|
||||||
|
nlatents.append((latents[i] - b) / s)
|
||||||
return nlatents
|
return nlatents
|
||||||
|
|
||||||
|
|
||||||
def local_norm(latents):
|
def local_norm(latents, exclusion_list=[]):
|
||||||
nlatents = []
|
nlatents = []
|
||||||
for i in range(len(latents)):
|
for i in range(len(latents)):
|
||||||
latent = latents[i]
|
latent = latents[i]
|
||||||
b, c, h, w = latent.shape
|
if i in exclusion_list:
|
||||||
s = latent.std(dim=[1]).view(1,1,h,w)
|
nlatents.append(latent)
|
||||||
b = latent.mean(dim=[1]).view(1,1,h,w)
|
else:
|
||||||
nlatents.append((latents[i] - b) / s)
|
b, c, h, w = latent.shape
|
||||||
|
s = latent.std(dim=[1]).view(1,1,h,w)
|
||||||
|
b = latent.mean(dim=[1]).view(1,1,h,w)
|
||||||
|
nlatents.append((latents[i] - b) / s)
|
||||||
return nlatents
|
return nlatents
|
||||||
|
|
||||||
|
|
||||||
|
# Extracts a rectangle of the same shape as <lat> from <ref> and returns it. This is taken from the center of <ref>
|
||||||
|
def extract_center_latent(ref, lat):
|
||||||
|
_, _, h, w = lat.shape
|
||||||
|
_, _, rh, rw = ref.shape
|
||||||
|
dw = (rw - w) / 2
|
||||||
|
dh = (rh - h) / 2
|
||||||
|
return ref[:, :, math.floor(dh):-math.ceil(dh), math.floor(dw):-math.ceil(dw)]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#### options
|
#### options
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
srg_analyze = False
|
srg_analyze = False
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow8x.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
utils.util.loaded_options = opt
|
utils.util.loaded_options = opt
|
||||||
|
@ -132,19 +150,19 @@ if __name__ == "__main__":
|
||||||
gen = model.networks['generator']
|
gen = model.networks['generator']
|
||||||
gen.eval()
|
gen.eval()
|
||||||
|
|
||||||
mode = "restore" # restore | latent_transfer | feed_through
|
mode = "feed_through" # restore | latent_transfer | feed_through
|
||||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
||||||
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*"
|
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\pure_adrianna_full\\images\\*"
|
||||||
desired_size = None # (640,640) # <- Required when doing style transfer.
|
scale = 2
|
||||||
scale = 8
|
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
|
||||||
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
|
temperature = .3
|
||||||
temperature = 1
|
|
||||||
output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
|
output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
|
||||||
|
|
||||||
# Data types <- used to perform latent transfer.
|
# Data types <- used to perform latent transfer.
|
||||||
data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
|
data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
|
||||||
#data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
|
data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*x-art-1912*", "*joli_high*", "*stacy-cruz*"]
|
||||||
data_type_filters = ["*lanette*"]
|
#data_type_filters = ["*lanette*"]
|
||||||
|
max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
|
||||||
max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
|
max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
|
||||||
interpolation_steps = 30
|
interpolation_steps = 30
|
||||||
|
|
||||||
|
@ -153,7 +171,7 @@ if __name__ == "__main__":
|
||||||
if mode == "latent_transfer":
|
if mode == "latent_transfer":
|
||||||
# Just get the **one** result for each pattern and use that latent.
|
# Just get the **one** result for each pattern and use that latent.
|
||||||
dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
|
dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
|
||||||
dt_transfers = [image_2_tensor(i, desired_size) for i in dt_imgs]
|
dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs]
|
||||||
# Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
|
# Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
|
||||||
for j in range(len(dt_transfers)):
|
for j in range(len(dt_transfers)):
|
||||||
if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
|
if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
|
||||||
|
@ -173,7 +191,7 @@ if __name__ == "__main__":
|
||||||
img_files = glob(imgs_to_resample_pattern)
|
img_files = glob(imgs_to_resample_pattern)
|
||||||
random.shuffle(img_files)
|
random.shuffle(img_files)
|
||||||
for im_it, img_file in enumerate(tqdm(img_files)):
|
for im_it, img_file in enumerate(tqdm(img_files)):
|
||||||
t = image_2_tensor(img_file, desired_size).to(model.env['device'])
|
t = image_2_tensor(img_file).to(model.env['device'])
|
||||||
if resample_factor != 1:
|
if resample_factor != 1:
|
||||||
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
|
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
|
||||||
resample_img = t
|
resample_img = t
|
||||||
|
@ -184,6 +202,7 @@ if __name__ == "__main__":
|
||||||
multiple_latents = False
|
multiple_latents = False
|
||||||
if mode == "restore":
|
if mode == "restore":
|
||||||
latents = local_norm(spatial_norm(latents))
|
latents = local_norm(spatial_norm(latents))
|
||||||
|
#latents = spatial_norm(latents)
|
||||||
latents = [l * temperature for l in latents]
|
latents = [l * temperature for l in latents]
|
||||||
elif mode == "feed_through":
|
elif mode == "feed_through":
|
||||||
latents = [torch.randn_like(l) * temperature for l in latents]
|
latents = [torch.randn_like(l) * temperature for l in latents]
|
||||||
|
@ -192,7 +211,7 @@ if __name__ == "__main__":
|
||||||
for slat in dt_latents:
|
for slat in dt_latents:
|
||||||
assert slat[0].shape[2] >= latents[0].shape[2]
|
assert slat[0].shape[2] >= latents[0].shape[2]
|
||||||
assert slat[0].shape[3] >= latents[0].shape[3]
|
assert slat[0].shape[3] >= latents[0].shape[3]
|
||||||
dts.append([sl[:,:,:l.shape[2],:l.shape[3]] * temperature for l, sl in zip(latents, slat)])
|
dts.append([extract_center_latent(sl, l) * temperature for l, sl in zip(latents, slat)])
|
||||||
latents = dts
|
latents = dts
|
||||||
multiple_latents = True
|
multiple_latents = True
|
||||||
|
|
||||||
|
@ -201,7 +220,6 @@ if __name__ == "__main__":
|
||||||
lats = [latents]
|
lats = [latents]
|
||||||
else:
|
else:
|
||||||
lats = latents
|
lats = latents
|
||||||
torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
|
|
||||||
for j in range(len(lats)):
|
for j in range(len(lats)):
|
||||||
hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
|
hr, _ = gen(lr=F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
|
||||||
z=lats[j][0],
|
z=lats[j][0],
|
||||||
|
@ -211,4 +229,5 @@ if __name__ == "__main__":
|
||||||
if torch.isnan(torch.max(hr)):
|
if torch.isnan(torch.max(hr)):
|
||||||
continue
|
continue
|
||||||
os.makedirs(os.path.join(output_path), exist_ok=True)
|
os.makedirs(os.path.join(output_path), exist_ok=True)
|
||||||
|
torchvision.utils.save_image(resample_img, os.path.join(output_path, "%i_orig.jpg" %(im_it)))
|
||||||
torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))
|
torchvision.utils.save_image(hr, os.path.join(output_path, "%i_%i.jpg" % (im_it,j)))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user