Fix srflow_latent_space_playground bug

This commit is contained in:
James Betker 2020-12-22 15:42:38 -07:00
parent e7aeb17404
commit 2437b33e74
3 changed files with 7 additions and 7 deletions

View File

@ -146,7 +146,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1
for i in range(steps):
proportion = i / (steps-1)
lats = linear_interpolation(latents1, latents2, proportion)
hr, _ = gen(lr=lq,
hr, _, _ = gen(lr=lq,
z=lats[0],
reverse=True,
epses=lats,
@ -179,9 +179,9 @@ if __name__ == "__main__":
gen = model.networks['generator']
gen.eval()
mode = "restore" # temperature | restore | latent_transfer | feed_through
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
mode = "feed_through" # temperature | restore | latent_transfer | feed_through
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.

View File

@ -48,7 +48,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_faces_glean.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt

View File

@ -216,8 +216,8 @@ class Trainer:
val_tqdm = tqdm(self.val_loader)
for val_data in val_tqdm:
idx += 1
for b in range(len(val_data['GT_path'])):
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
for b in range(len(val_data['HQ_path'])):
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)