Fix srflow_latent_space_playground bug
This commit is contained in:
parent
e7aeb17404
commit
2437b33e74
|
@ -146,7 +146,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1
|
||||||
for i in range(steps):
|
for i in range(steps):
|
||||||
proportion = i / (steps-1)
|
proportion = i / (steps-1)
|
||||||
lats = linear_interpolation(latents1, latents2, proportion)
|
lats = linear_interpolation(latents1, latents2, proportion)
|
||||||
hr, _ = gen(lr=lq,
|
hr, _, _ = gen(lr=lq,
|
||||||
z=lats[0],
|
z=lats[0],
|
||||||
reverse=True,
|
reverse=True,
|
||||||
epses=lats,
|
epses=lats,
|
||||||
|
@ -179,9 +179,9 @@ if __name__ == "__main__":
|
||||||
gen = model.networks['generator']
|
gen = model.networks['generator']
|
||||||
gen.eval()
|
gen.eval()
|
||||||
|
|
||||||
mode = "restore" # temperature | restore | latent_transfer | feed_through
|
mode = "feed_through" # temperature | restore | latent_transfer | feed_through
|
||||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
||||||
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
|
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
|
||||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
|
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
|
||||||
scale = 2
|
scale = 2
|
||||||
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.
|
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.
|
||||||
|
|
|
@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
want_metrics = False
|
want_metrics = False
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_faces_glean.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
utils.util.loaded_options = opt
|
utils.util.loaded_options = opt
|
||||||
|
|
|
@ -216,8 +216,8 @@ class Trainer:
|
||||||
val_tqdm = tqdm(self.val_loader)
|
val_tqdm = tqdm(self.val_loader)
|
||||||
for val_data in val_tqdm:
|
for val_data in val_tqdm:
|
||||||
idx += 1
|
idx += 1
|
||||||
for b in range(len(val_data['GT_path'])):
|
for b in range(len(val_data['HQ_path'])):
|
||||||
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
|
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
|
||||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||||
|
|
||||||
util.mkdir(img_dir)
|
util.mkdir(img_dir)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user