srflow latent space misc

This commit is contained in:
James Betker 2020-12-14 23:59:49 -07:00
parent 1e14635d88
commit e5a3e6b9b5
2 changed files with 20 additions and 16 deletions

View File

@ -143,7 +143,8 @@ def slerp(latents1, latents2, proportion):
return res
def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10):
def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10, prefix=''):
# Outputs a series of images interpolated from [latents1] to [latents2]. image 0 biases towards latents2.
for i in range(steps):
proportion = i / (steps-1)
lats = linear_interpolation(latents1, latents2, proportion)
@ -153,7 +154,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1
epses=lats,
add_gt_noise=False)
torchvision.transforms.ToPILImage()(hr.squeeze(0).cpu())
torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%i.png" % (i,)))
torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%s_%i.png" % (prefix, i)))
# Stopped using this because PILs animated gif output is total crap.
#images[0].save(output_file, save_all=True, append_images=images[1:], duration=80, loop=0)
@ -163,7 +164,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@ -185,15 +186,15 @@ if __name__ == "__main__":
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.
temperature = 1
output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
output_path = "..\\..\\results\\latent_playground"
# Data types <- used to perform latent transfer.
data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
#data_type_filters = ["*lanette*"]
max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
#data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
data_type_filters = ["*lanette*"]
max_size = 1600 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
interpolation_steps = 30
@ -204,9 +205,9 @@ if __name__ == "__main__":
dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs]
# Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
for j in range(len(dt_transfers)):
if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area')
#for j in range(len(dt_transfers)):
# if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
# dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area')
corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']})
def corrupt_and_downsample(img, scale):
img = F.interpolate(img, scale_factor=1 / scale, mode="area")
@ -219,11 +220,13 @@ if __name__ == "__main__":
# Fetch the images to resample.
img_files = glob(imgs_to_resample_pattern)
random.shuffle(img_files)
#random.shuffle(img_files)
for im_it, img_file in enumerate(tqdm(img_files)):
t = image_2_tensor(img_file).to(model.env['device'])
if resample_factor != 1:
if resample_factor > 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
elif resample_factor < 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="area")
# Ensure the input image is a factor of 16.
_, _, h, w = t.shape
h = 16 * (h // 16)
@ -260,6 +263,6 @@ if __name__ == "__main__":
for j in range(len(lats)):
path = os.path.join(output_path, "%i_%i" % (im_it, j))
os.makedirs(path, exist_ok=True)
torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
torchvision.utils.save_image(resample_img, os.path.join(path, "orig_%i.jpg" % (im_it)))
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
path, [torch.zeros_like(l) for l in lats[j]], lats[j])
path, [torch.zeros_like(l) for l in lats[j]], lats[j], prefix=mode)

View File

@ -219,6 +219,7 @@ class Trainer:
for b in range(len(val_data['GT_path'])):
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
self.model.feed_data(val_data, self.current_step)
@ -292,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_psnr_4x.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()