forked from mrq/DL-Art-School
srflow latent space misc
This commit is contained in:
parent
1e14635d88
commit
e5a3e6b9b5
|
@ -143,7 +143,8 @@ def slerp(latents1, latents2, proportion):
|
|||
return res
|
||||
|
||||
|
||||
def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10):
|
||||
def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10, prefix=''):
|
||||
# Outputs a series of images interpolated from [latents1] to [latents2]. image 0 biases towards latents2.
|
||||
for i in range(steps):
|
||||
proportion = i / (steps-1)
|
||||
lats = linear_interpolation(latents1, latents2, proportion)
|
||||
|
@ -153,7 +154,7 @@ def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=1
|
|||
epses=lats,
|
||||
add_gt_noise=False)
|
||||
torchvision.transforms.ToPILImage()(hr.squeeze(0).cpu())
|
||||
torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%i.png" % (i,)))
|
||||
torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%s_%i.png" % (prefix, i)))
|
||||
# Stopped using this because PILs animated gif output is total crap.
|
||||
#images[0].save(output_file, save_all=True, append_images=images[1:], duration=80, loop=0)
|
||||
|
||||
|
@ -163,7 +164,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
srg_analyze = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
@ -185,15 +186,15 @@ if __name__ == "__main__":
|
|||
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
|
||||
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
|
||||
scale = 2
|
||||
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
|
||||
resample_factor = 2 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. E.g. set this to '2' to get 2x upsampling.
|
||||
temperature = 1
|
||||
output_path = "E:\\4k6k\\mmsr\\results\\latent_playground"
|
||||
output_path = "..\\..\\results\\latent_playground"
|
||||
|
||||
# Data types <- used to perform latent transfer.
|
||||
data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
|
||||
data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
|
||||
#data_type_filters = ["*lanette*"]
|
||||
max_size = 1100 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
|
||||
#data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
|
||||
data_type_filters = ["*lanette*"]
|
||||
max_size = 1600 # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.
|
||||
max_ref_datatypes = 30 # Only picks this many images from the above data types to sample from.
|
||||
interpolation_steps = 30
|
||||
|
||||
|
@ -204,9 +205,9 @@ if __name__ == "__main__":
|
|||
dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
|
||||
dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs]
|
||||
# Downsample the images because they are often just too big to feed through the network (probably needs to be parameterized)
|
||||
for j in range(len(dt_transfers)):
|
||||
if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
|
||||
dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area')
|
||||
#for j in range(len(dt_transfers)):
|
||||
# if min(dt_transfers[j].shape[2], dt_transfers[j].shape[3]) > 1600:
|
||||
# dt_transfers[j] = F.interpolate(dt_transfers[j], scale_factor=1 / 2, mode='area')
|
||||
corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']})
|
||||
def corrupt_and_downsample(img, scale):
|
||||
img = F.interpolate(img, scale_factor=1 / scale, mode="area")
|
||||
|
@ -219,11 +220,13 @@ if __name__ == "__main__":
|
|||
|
||||
# Fetch the images to resample.
|
||||
img_files = glob(imgs_to_resample_pattern)
|
||||
random.shuffle(img_files)
|
||||
#random.shuffle(img_files)
|
||||
for im_it, img_file in enumerate(tqdm(img_files)):
|
||||
t = image_2_tensor(img_file).to(model.env['device'])
|
||||
if resample_factor != 1:
|
||||
if resample_factor > 1:
|
||||
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
|
||||
elif resample_factor < 1:
|
||||
t = F.interpolate(t, scale_factor=resample_factor, mode="area")
|
||||
# Ensure the input image is a factor of 16.
|
||||
_, _, h, w = t.shape
|
||||
h = 16 * (h // 16)
|
||||
|
@ -260,6 +263,6 @@ if __name__ == "__main__":
|
|||
for j in range(len(lats)):
|
||||
path = os.path.join(output_path, "%i_%i" % (im_it, j))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
|
||||
torchvision.utils.save_image(resample_img, os.path.join(path, "orig_%i.jpg" % (im_it)))
|
||||
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
|
||||
path, [torch.zeros_like(l) for l in lats[j]], lats[j])
|
||||
path, [torch.zeros_like(l) for l in lats[j]], lats[j], prefix=mode)
|
||||
|
|
|
@ -219,6 +219,7 @@ class Trainer:
|
|||
for b in range(len(val_data['GT_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
|
||||
util.mkdir(img_dir)
|
||||
|
||||
self.model.feed_data(val_data, self.current_step)
|
||||
|
@ -292,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_psnr_4x.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user