forked from mrq/DL-Art-School
Fix srflow_latent_space_playground bug
This commit is contained in:
parent
e7aeb17404
commit
2437b33e74
|
@ -146,7 +146,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1
|
|||
for i in range(steps):
|
||||
proportion = i / (steps-1)
|
||||
lats = linear_interpolation(latents1, latents2, proportion)
|
||||
hr, _ = gen(lr=lq,
|
||||
hr, _, _ = gen(lr=lq,
|
||||
z=lats[0],
|
||||
reverse=True,
|
||||
epses=lats,
|
||||
|
@ -179,9 +179,9 @@ if __name__ == "__main__":
|
|||
gen = model.networks['generator']
|
||||
gen.eval()
|
||||
|
||||
mode = "restore" # temperature | restore | latent_transfer | feed_through
|
||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
||||
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
|
||||
mode = "feed_through" # temperature | restore | latent_transfer | feed_through
|
||||
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
|
||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
|
||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
|
||||
scale = 2
|
||||
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.
|
||||
|
|
|
@ -48,7 +48,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_faces_glean.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
|
|
@ -216,8 +216,8 @@ class Trainer:
|
|||
val_tqdm = tqdm(self.val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['GT_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
|
||||
for b in range(len(val_data['HQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
|
||||
util.mkdir(img_dir)
|
||||
|
|
Loading…
Reference in New Issue
Block a user