2020-11-25 20:59:06 +00:00
import argparse
import logging
import math
import os
from glob import glob
import torch
import torch . nn . functional as F
import torchvision
from PIL import Image
from tqdm import tqdm
import utils . options as option
import utils
from data . image_corruptor import ImageCorruptor
2020-12-18 16:18:34 +00:00
from trainer . ExtensibleTrainer import ExtensibleTrainer
2020-11-25 20:59:06 +00:00
from utils import util
2020-11-27 19:03:16 +00:00
def image_2_tensor(impath, max_size=None):
    """Load an image file as a (1, 3, H, W) float tensor for the generator.

    :param impath: Path to the image file.
    :param max_size: If given, the image is resized (Lanczos) so that its larger
        side becomes ``max_size`` while keeping the aspect ratio. This can
        upscale as well as downscale.
    :return: Tensor of shape (1, 3, H, W) produced by torchvision's ToTensor,
        with H and W cropped from the bottom/right down to multiples of 16.
    """
    # A context manager releases the file handle; otherwise PIL keeps it open
    # until the Image object happens to be garbage collected.
    with Image.open(impath) as src:
        img = src
        # Normalise to RGB so grayscale/palette/CMYK inputs also yield three
        # colour channels. RGBA -> RGB simply drops alpha, which is what the
        # [:3] channel slice below did for RGBA inputs.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if max_size is not None:
            factor = min(max_size / img.width, max_size / img.height)
            new_size = (int(math.ceil(img.width * factor)), int(math.ceil(img.height * factor)))
            img = img.resize(new_size, Image.LANCZOS)
        timg = torchvision.transforms.ToTensor()(img).unsqueeze(0)

    # Enforce spatial dimensions that are a multiple of 16.
    _, _, h, w = timg.shape
    h = (h // 16) * 16
    w = (w // 16) * 16
    return timg[:, :3, :h, :w]
def interpolate_lr(hr, scale):
    """Return ``hr`` area-downsampled by a factor of ``scale`` (e.g. 2 halves H and W)."""
    inverse_scale = 1 / scale
    lr = F.interpolate(hr, mode="area", scale_factor=inverse_scale)
    return lr
def fetch_latents_for_image(gen, img, scale, lr_infer=interpolate_lr):
    """Run ``gen`` with ``reverse=False`` on one image and return its latent output.

    ``img`` is passed as the ground truth and ``lr_infer(img, scale)`` as the
    low-resolution conditioning input; no epsilons are supplied and GT noise is
    disabled.

    :return: The first element of the generator's 3-tuple output.
    """
    lr = lr_infer(img, scale)
    outputs = gen(gt=img, lr=lr, epses=[], reverse=False, add_gt_noise=False)
    z, _, _ = outputs
    return z
def fetch_latents_for_images(gen, imgs, scale, lr_infer=interpolate_lr):
    """Encode every image in ``imgs`` with ``fetch_latents_for_image``.

    :param gen: Generator called with ``reverse=False``.
    :param imgs: Iterable of image tensors.
    :param scale: Downsampling factor forwarded to ``lr_infer``.
    :param lr_infer: Callable ``(img, scale) -> lr`` building the conditioning input.
    :return: List with one latent output per input image, in input order.
    """
    # Delegate to the single-image helper so the generator call lives in one place.
    return [fetch_latents_for_image(gen, img, scale, lr_infer) for img in imgs]
def fetch_spatial_metrics_for_latents(latents):
    """Compute a std/mean pair for each group of latent tensors.

    Each ``latents[k]`` is a sequence of equally-shaped tensors. They are stacked
    on a new trailing axis, a leading size-1 axis (if present) is squeezed away,
    and the unbiased standard deviation and the mean are reduced over every axis
    except the first. Presumably each tensor is (1, C, H, W), giving one statistic
    per channel -- TODO confirm.

    :return: ``(scales, biases)``: two lists with one entry per group, each viewed
        as (1, -1, 1, 1).
    """
    reduce_dims = [1, 2, 3]
    scales, biases = [], []
    for group in latents:
        merged = torch.stack(group, dim=-1).squeeze(0)
        scales.append(merged.std(dim=reduce_dims).view(1, -1, 1, 1))
        biases.append(merged.mean(dim=reduce_dims).view(1, -1, 1, 1))
    return scales, biases
2020-11-27 19:03:16 +00:00
def spatial_norm(latents, exclusion_list=None):
    """Standardise each latent per (sample, channel) over its spatial dimensions.

    Every tensor ``latents[i]`` (indexed as (B, C, H, W)) is shifted and scaled so
    that each (sample, channel) plane has zero mean and unit (unbiased) standard
    deviation over H and W.

    :param latents: Sequence of 4D tensors.
    :param exclusion_list: Optional collection of indices into ``latents`` whose
        tensors are returned unchanged (the same object, not a copy).
    :return: New list of tensors, same length and order as ``latents``.
    """
    # None rather than a mutable [] default, which would be shared across calls.
    if exclusion_list is None:
        exclusion_list = ()
    nlatents = []
    for i, latent in enumerate(latents):
        if i in exclusion_list:
            nlatents.append(latent)
            continue
        # keepdim=True yields (B, C, 1, 1) statistics that broadcast for any batch
        # size; a fixed .view(1, C, 1, 1) only works when B == 1.
        # NOTE: a constant plane has std 0 and produces inf/nan, as before.
        std = latent.std(dim=[2, 3], keepdim=True)
        mean = latent.mean(dim=[2, 3], keepdim=True)
        nlatents.append((latent - mean) / std)
    return nlatents
2020-11-27 19:03:16 +00:00
def local_norm(latents, exclusion_list=None):
    """Standardise each latent across channels at every spatial position.

    Every tensor ``latents[i]`` (indexed as (B, C, H, W)) is shifted and scaled so
    that, for each sample and each (h, w) location, the C channel values have zero
    mean and unit (unbiased) standard deviation.

    :param latents: Sequence of 4D tensors.
    :param exclusion_list: Optional collection of indices into ``latents`` whose
        tensors are returned unchanged (the same object, not a copy).
    :return: New list of tensors, same length and order as ``latents``.
    """
    # None rather than a mutable [] default, which would be shared across calls.
    if exclusion_list is None:
        exclusion_list = ()
    nlatents = []
    for i, latent in enumerate(latents):
        if i in exclusion_list:
            nlatents.append(latent)
            continue
        # keepdim=True yields (B, 1, H, W) statistics that broadcast for any batch
        # size; a fixed .view(1, 1, H, W) only works when B == 1.
        # NOTE: zero variance across channels produces inf/nan, as before.
        std = latent.std(dim=1, keepdim=True)
        mean = latent.mean(dim=1, keepdim=True)
        nlatents.append((latent - mean) / std)
    return nlatents
2020-11-27 19:03:16 +00:00
def extract_center_latent(ref, lat):
    """Crop the centre of ``ref`` to the spatial size of ``lat``.

    Both tensors are indexed as (N, C, H, W); only H and W of ``lat`` are used.
    When the size difference is odd, the extra row/column is dropped from the
    bottom/right (the top/left offset is floored).

    :raises ValueError: if ``ref`` is spatially smaller than ``lat``.
    """
    _, _, h, w = lat.shape
    _, _, rh, rw = ref.shape
    if rh < h or rw < w:
        raise ValueError("ref (%ix%i) is smaller than lat (%ix%i)" % (rh, rw, h, w))
    # Use explicit start/stop indices: an end index of -ceil(d) collapses to 0
    # (an empty slice) when d == 0, i.e. when ref and lat already match in size.
    top = (rh - h) // 2
    left = (rw - w) // 2
    return ref[:, :, top:top + h, left:left + w]
2020-11-29 16:33:29 +00:00
def linear_interpolation(latents1, latents2, proportion):
    """Blend two latent sequences element-wise.

    Each output is ``proportion * a + (1 - proportion) * b`` for paired elements
    ``a`` of ``latents1`` and ``b`` of ``latents2``, so ``proportion=1`` reproduces
    ``latents1`` and ``proportion=0`` reproduces ``latents2``. Pairing stops at the
    shorter sequence.
    """
    weight_first = proportion
    weight_second = 1 - proportion
    blended = []
    for first, second in zip(latents1, latents2):
        blended.append(first * weight_first + second * weight_second)
    return blended
def slerp(latents1, latents2, proportion):
    """Spherically interpolate between two lists of latents.

    For each pair ``(low, high)`` of (B, C, H, W) tensors, every (sample, channel)
    plane is treated as a vector over H*W. With ``w`` the angle between the
    corresponding planes, the result is
    ``sin((1 - p) * w) / sin(w) * low + sin(p * w) / sin(w) * high``.
    ``proportion=0`` returns ``low`` and ``proportion=1`` returns ``high``; note this
    is the opposite weighting convention to ``linear_interpolation``.

    Where the planes are (anti)parallel (``sin(w)`` ~ 0), it falls back to linear
    interpolation with the same convention.
    """
    eps = 1e-7
    res = []
    for low, high in zip(latents1, latents2):
        # clamp_min avoids 0/0 for all-zero planes (e.g. torch.zeros_like latents).
        low_norm = low / torch.norm(low, dim=[2, 3], keepdim=True).clamp_min(eps)
        high_norm = high / torch.norm(high, dim=[2, 3], keepdim=True).clamp_min(eps)
        # Take the dot product over the same dims that were normalised, so it is a
        # true cosine. The clamp guards acos against rounding just outside [-1, 1].
        cos = (low_norm * high_norm).sum(dim=[2, 3], keepdim=True).clamp(-1.0, 1.0)
        omega = torch.acos(cos)
        so = torch.sin(omega)
        degenerate = so.abs() < eps
        safe_so = torch.where(degenerate, torch.ones_like(so), so)
        w_low = torch.where(degenerate,
                            torch.full_like(so, 1.0 - proportion),
                            torch.sin((1.0 - proportion) * omega) / safe_so)
        w_high = torch.where(degenerate,
                             torch.full_like(so, float(proportion)),
                             torch.sin(proportion * omega) / safe_so)
        res.append(w_low * low + w_high * high)
    return res
2020-12-15 06:59:49 +00:00
def create_interpolation_video(gen, lq, output_file, latents1, latents2, steps=10, prefix=''):
    """Render ``steps`` frames interpolated between ``latents2`` and ``latents1``.

    Frame ``i`` uses ``linear_interpolation(latents1, latents2, i / (steps - 1))``:
    frame 0 is generated from ``latents2`` and the last frame from ``latents1``.
    Each frame is decoded by calling ``gen`` with ``reverse=True``, conditioned on
    ``lq``. The first interpolated latent is passed as ``z`` and the full list as
    ``epses``.

    :param output_file: Despite the name, a *directory*. Frames are written to it
        as ``<prefix>_<i>.png``; it is not created here.
    :param steps: Number of frames. ``steps == 1`` renders a single frame from
        ``latents2``.
    """
    for i in range(steps):
        # i / (steps - 1) would divide by zero for a single frame.
        proportion = i / (steps - 1) if steps > 1 else 0.0
        lats = linear_interpolation(latents1, latents2, proportion)
        hr, _, _ = gen(lr=lq,
                       z=lats[0],
                       reverse=True,
                       epses=lats,
                       add_gt_noise=False)
        torchvision.utils.save_image(hr.cpu(), os.path.join(output_file, "%s_%i.png" % (prefix, i)))
    # Frames are written as individual PNGs rather than an animated GIF because
    # PIL's animated GIF output was of poor quality.
2020-11-25 20:59:06 +00:00
if __name__ == "__main__":
    # Latent playground: encode images with the generator (reverse=False), transform
    # the resulting latents according to `mode`, then render frames interpolating
    # from all-zero latents to the transformed ones (see create_interpolation_video).
    #
    # Modes, as implemented below:
    #   temperature     - image latents scaled by `temperature`.
    #   restore         - image latents passed through spatial_norm, then local_norm, then scaled.
    #   latent_transfer - centre crops of latents taken from reference images, scaled.
    #   feed_through    - fresh standard-normal latents of the same shapes, scaled.

    #### options
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.',
                        default='../../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    model = ExtensibleTrainer(opt)
    gen = model.networks['generator']
    gen.eval()

    mode = "feed_through"  # temperature | restore | latent_transfer | feed_through
    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
    scale = 2
    # When != 1, each input image is resampled by this factor (bicubic when > 1, area
    # when < 1) before its latents are computed. E.g. set this to 2 for 2x upsampling.
    resample_factor = 2
    temperature = 1
    output_path = "..\\..\\results\\latent_playground"
    interpolation_steps = 30  # Number of frames rendered into each output directory.

    # Data types <- used to perform latent transfer.
    data_path = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half"
    #data_type_filters = ["*alexa*", "*lanette*", "*80755*", "*joli_high*"]
    data_type_filters = ["*lanette*"]
    max_size = 1600  # Should be set to 2x the largest single dimension of the input space, otherwise an error will occur.

    with torch.no_grad():
        # Compute latent variables for the reference images.
        if mode == "latent_transfer":
            from data.util import torch2cv, cv2torch

            # Just get the **one** result for each pattern and use that latent.
            # NOTE(review): this actually takes the 5th-to-last glob match and raises
            # IndexError when a pattern matches fewer than 5 files -- confirm intent.
            dt_imgs = [glob(os.path.join(data_path, p))[-5] for p in data_type_filters]
            # image_2_tensor resizes each reference so its larger side is max_size.
            # NOTE(review): unlike the images resampled below, these tensors are not
            # moved to model.env['device']; confirm gen accepts them as-is.
            dt_transfers = [image_2_tensor(i, max_size) for i in dt_imgs]
            corruptor = ImageCorruptor({'fixed_corruptions': ['jpeg-medium', 'gaussian_blur_3']})

            def corrupt_and_downsample(img, scale):
                # Area-downsample, then apply the fixed corruptions via the cv image API.
                img = F.interpolate(img, scale_factor=1 / scale, mode="area")
                cvimg = torch2cv(img)
                cvimg = corruptor.corrupt_images([cvimg])[0]
                return cv2torch(cvimg)

            dt_latents = [fetch_latents_for_image(gen, i, scale, corrupt_and_downsample) for i in dt_transfers]

        # Fetch the images to resample.
        img_files = glob(imgs_to_resample_pattern)
        for im_it, img_file in enumerate(tqdm(img_files)):
            t = image_2_tensor(img_file).to(model.env['device'])
            if resample_factor > 1:
                t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
            elif resample_factor < 1:
                t = F.interpolate(t, scale_factor=resample_factor, mode="area")
            # Ensure the input image dimensions are multiples of 16.
            _, _, h, w = t.shape
            h = 16 * (h // 16)
            w = 16 * (w // 16)
            resample_img = t[:, :, :h, :w]

            # Fetch the latents for the image being resampled.
            latents = fetch_latents_for_images(gen, [resample_img], scale)[0]
            multiple_latents = False
            if mode == "restore":
                latents = local_norm(spatial_norm(latents))
                latents = [l * temperature for l in latents]
            elif mode == "feed_through":
                latents = [torch.randn_like(l) * temperature for l in latents]
            elif mode == "latent_transfer":
                dts = []
                for slat in dt_latents:
                    assert slat[0].shape[2] >= latents[0].shape[2]
                    assert slat[0].shape[3] >= latents[0].shape[3]
                    dts.append([extract_center_latent(sl, l) * temperature for l, sl in zip(latents, slat)])
                latents = dts
                multiple_latents = True
            elif mode == "temperature":
                latents = [l * temperature for l in latents]

            # Re-compute each image with the new latents. The LQ conditioning input
            # is the same for every latent set, so compute it once per image.
            lats = latents if multiple_latents else [latents]
            lq = F.interpolate(resample_img, scale_factor=1 / scale, mode="area")
            for j, lat_set in enumerate(lats):
                path = os.path.join(output_path, "%i_%i" % (im_it, j))
                os.makedirs(path, exist_ok=True)
                torchvision.utils.save_image(resample_img, os.path.join(path, "orig_%i.jpg" % (im_it)))
                create_interpolation_video(gen, lq, path, [torch.zeros_like(l) for l in lat_set], lat_set,
                                           steps=interpolation_steps, prefix=mode)