2020-10-07 15:03:30 +00:00
from models . steps . losses import ConfigurableLoss , GANLoss , extract_params_from_state , get_basic_criterion_for_name
2020-10-06 01:35:28 +00:00
from models . layers . resample2d_package . resample2d import Resample2d
from models . steps . recurrent import RecurrentController
from models . steps . injectors import Injector
import torch
2020-10-07 15:03:30 +00:00
import torch . nn . functional as F
2020-10-07 02:35:39 +00:00
import os
import os . path as osp
import torchvision
2020-10-06 01:35:28 +00:00
def create_teco_loss(opt, env):
    """Construct the TecoGAN loss selected by ``opt['type']``.

    Returns a ``TecoGanLoss`` for ``'teco_gan'``, a ``PingPongLoss`` for
    ``'teco_pingpong'`` and ``None`` for any other type.
    Raises ``KeyError`` if ``opt`` has no ``'type'`` entry.
    """
    loss_type = opt['type']
    if loss_type == 'teco_gan':
        return TecoGanLoss(opt, env)
    if loss_type == "teco_pingpong":
        return PingPongLoss(opt, env)
    return None
2020-10-07 15:03:30 +00:00
def create_teco_injector(opt, env):
    """Construct the TecoGAN injector selected by ``opt['type']``.

    Returns a ``RecurrentImageGeneratorSequenceInjector`` for
    ``'teco_recurrent_generated_sequence_injector'`` and ``None`` otherwise.
    Raises ``KeyError`` if ``opt`` has no ``'type'`` entry.
    """
    injector_type = opt['type']
    if injector_type != 'teco_recurrent_generated_sequence_injector':
        return None
    return RecurrentImageGeneratorSequenceInjector(opt, env)
2020-10-06 01:35:28 +00:00
2020-10-07 15:03:30 +00:00
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler):
    """Build the 18-channel temporal-discriminator input for the triplet starting at ``index``.

    The triplet ``input_list[:, index:index + 3]`` is concatenated with a
    motion-compensated copy of itself. In that copy, the first and last frames are
    warped onto the middle frame using flow estimated from the matching LR frames.

    Args:
        input_list: frame sequence indexed ``[batch, frame, channel, h, w]``. The final
            ``view`` requires 3 channels per frame.
        lr_imgs: low-resolution sequence, assumed to be aligned frame-for-frame with
            ``input_list`` and ``scale`` times smaller spatially (TODO confirm at callers).
        scale: spatial factor between ``lr_imgs`` and ``input_list``.
        index: index of the first frame of the triplet.
        flow_gen: flow network called as ``flow_gen(stack([reference, source], dim=2))``.
            This is the same argument order RecurrentImageGeneratorSequenceInjector uses.
        resampler: warp op called as ``resampler(source, flow)``.

    Returns:
        A ``(b, 18, h, w)`` tensor laid out as
        [frame0, frame1, frame2, warped frame0, frame1, warped frame2], 3 channels each.
    """
    triplet = input_list[:, index:index + 3]
    # Use the LR frames belonging to *this* triplet (previously frames 0-2 were always used).
    lr_triplet = lr_imgs[:, index:index + 3]
    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
    with torch.no_grad():
        # The middle frame is the reference and the outer frame is the source to be warped onto it.
        first_flow = flow_gen(torch.stack([lr_triplet[:, 1], lr_triplet[:, 0]], dim=2))
        last_flow = flow_gen(torch.stack([lr_triplet[:, 1], lr_triplet[:, 2]], dim=2))
        # Resample2d offsets are assumed to be in pixels, so the vectors are scaled
        # together with the spatial upsampling.
        first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') * scale
        last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') * scale
    # Resample does not work in FP16. Cast the warped frames back so every stacked tensor
    # shares the input's dtype.
    flow_triplet = [resampler(triplet[:, 0].float(), first_flow.float()).to(triplet.dtype),
                    triplet[:, 1],
                    resampler(triplet[:, 2].float(), last_flow.float()).to(triplet.dtype)]
    # Stack and concatenate along the frame axis so frames stay contiguous 3-channel groups.
    # Using dim=2 here interleaved channels with frames.
    flow_triplet = torch.stack(flow_triplet, dim=1)
    combined = torch.cat([triplet, flow_triplet], dim=1)
    b, f, c, h, w = combined.shape
    return combined.view(b, 3 * 6, h, w)  # 3*6 is essentially an assertion here.
def extract_inputs_index(inputs, i):
    """Return a new list where every tensor in ``inputs`` is replaced by ``tensor[:, i]``.

    Non-tensor entries are passed through unchanged. The recurrent injector uses this
    to pull out frame ``i`` of each sequence input. Because a fresh list is returned,
    callers can modify it without affecting ``inputs``.
    """
    return [item[:, i] if isinstance(item, torch.Tensor) else item for item in inputs]
2020-10-06 01:35:28 +00:00
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out].
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
# Only the flow estimation and warping of the recurrent input run under torch.no_grad(); the generator itself
# runs with gradients enabled.
class RecurrentImageGeneratorSequenceInjector(Injector):
    """Run a recurrent generator over a frame sequence, first forward and then backward.

    Options read from ``opt``:
        generator: name of the recurrent generator in ``env['generators']``.
        flow_network: name of the flow network in ``env['generators']``.
        input_lq_index: which generator input holds the LQ frame sequence (default 0).
        output_hq_index: which generator output is fed back as the recurrent input (default 0).
        scale: factor by which the generator output is downsampled to match the LQ frames.

    ``self.input`` and ``self.output`` come from the ``Injector`` base class, which is
    not shown here.
    """

    def __init__(self, opt, env):
        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
        self.flow = opt['flow_network']
        self.input_lq_index = opt.get('input_lq_index', 0)
        self.output_hq_index = opt.get('output_hq_index', 0)
        self.scale = opt['scale']
        self.resample = Resample2d()

    def _warp_recurrent(self, flow, recurrent_input, lq_frame):
        """Downsample the previous output to LQ size and warp it onto ``lq_frame``.

        Runs under ``torch.no_grad()``, so the returned tensor carries no gradient.
        """
        with torch.no_grad():
            reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
            # (reference, source) order: the current LQ frame is the reference.
            flow_input = torch.stack([lq_frame, reduced_recurrent], dim=2)
            flowfield = flow(flow_input)
            # Resample does not work in FP16.
            return self.resample(reduced_recurrent.float(), flowfield.float())

    def forward(self, state):
        """Generate outputs for frames 0..f-1 and then f-2..0.

        Returns ``{self.output: results}``, where ``results`` is a list of
        ``2 * f - 1`` generator outputs in processing order.
        """
        gen = self.env['generators'][self.opt['generator']]
        flow = self.env['generators'][self.flow]
        inputs = extract_params_from_state(self.input, state)
        if not isinstance(inputs, list):
            inputs = [inputs]
        lq_index = self.input_lq_index
        # The first frame has no previous output, so it is paired with a blank LQ-sized frame.
        recurrent_input = torch.zeros_like(inputs[lq_index][:, 0])
        b, f, c, h, w = inputs[lq_index].shape

        # Go forward through the sequence, then backwards, skipping the last frame
        # (its output is already stored in recurrent_input).
        frame_order = list(range(f)) + list(reversed(range(f - 1)))
        results = []
        for step, frame_idx in enumerate(frame_order):
            frame_inputs = extract_inputs_index(inputs, frame_idx)
            if step > 0:
                recurrent_input = self._warp_recurrent(flow, recurrent_input, frame_inputs[lq_index])
            frame_inputs[lq_index] = torch.cat([frame_inputs[lq_index], recurrent_input], dim=1)
            gen_out = gen(*frame_inputs)
            recurrent_input = gen_out[self.output_hq_index]
            results.append(recurrent_input)
        return {self.output: results}
# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contract for 'real' and 'fake' inputs:
# 'real' - Must be a sequence of images (at least 3 frames, indexed [batch, frame, ...]) drawn from the dataset.
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
#
# This loss does the following:
# 1) Picks an image triplet, starting with the first 3 elements in 'real' and 'fake'.
# 2) Uses the image flow generator (specified with 'image_flow_generator') on the corresponding LR frames ('lr_inputs') to create detached flow fields for the first and last images of the triplet.
# 3) Warps the first and last images according to those flow fields.
# 4) Concatenates the three base images with the two warped images and the middle image along the filter dimension, for both real and fake, resulting in a bx18xhxw tensor.
# 5) Feeds the concatenated real and fake image sets into the discriminator and accumulates a loss. The total is returned; backward() is not called here.
# 6) Repeats from (1), advancing the triplet by one frame, until all triplets from the real sequence have been exhausted.
2020-10-07 15:03:30 +00:00
class TecoGanLoss(ConfigurableLoss):
    """TecoGAN temporal (sextuplet) adversarial loss.

    Options read from ``opt``: gan_type, scale, lr_inputs, image_flow_generator,
    for_generator, discriminator, real, fake.
    """

    def __init__(self, opt, env):
        super(TecoGanLoss, self).__init__(opt, env)
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
        # TecoGAN parameters
        self.scale = opt['scale']
        self.lr_inputs = opt['lr_inputs']
        self.image_flow_generator = opt['image_flow_generator']
        self.resampler = Resample2d()
        # True: compute the generator's adversarial loss. False: compute the discriminator's loss.
        self.for_generator = opt['for_generator']

    def forward(self, _, state):
        """Sum the adversarial loss over every frame triplet of the real sequence.

        Raises NotImplementedError for an unsupported gan_type. The check only runs
        once the loop executes, i.e. when the sequence has at least 3 frames.
        """
        net = self.env['discriminators'][self.opt['discriminator']]
        flow_gen = self.env['generators'][self.image_flow_generator]
        real = state[self.opt['real']]
        # The injector emits a list of frames; stack it along the frame axis to match 'real'.
        fake = torch.stack(state[self.opt['fake']], dim=1)
        sequence_len = real.shape[1]
        lr = state[self.lr_inputs]
        l_total = 0
        for i in range(sequence_len - 2):
            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler)
            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler)
            d_fake = net(fake_sext)

            if self.for_generator and self.env['step'] % 100 == 0:
                self.produce_teco_visual_debugs(fake_sext, 'fake', i)
                self.produce_teco_visual_debugs(real_sext, 'real', i)

            if self.opt['gan_type'] in ['gan', 'pixgan']:
                self.metrics.append(("d_fake", torch.mean(d_fake)))
                l_fake = self.criterion(d_fake, self.for_generator)
                if not self.for_generator:
                    # The discriminator must also see the real sextuplet. d_real was
                    # never computed in this branch before, which raised a NameError.
                    d_real = net(real_sext)
                    l_real = self.criterion(d_real, True)
                else:
                    l_real = 0
                l_total += l_fake + l_real
            elif self.opt['gan_type'] == 'ragan':
                d_real = net(real_sext)
                d_fake_diff = d_fake - torch.mean(d_real)
                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
                l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
                            self.criterion(d_fake_diff, self.for_generator))
            else:
                raise NotImplementedError

        return l_total

    def produce_teco_visual_debugs(self, sext, lbl, it):
        """Save each 3-channel group of an 18-channel sextuplet as a PNG for debugging."""
        base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
        os.makedirs(base_path, exist_ok=True)
        lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
        for i, name in enumerate(lbls):
            # Each image spans 3 channels. The old end bound ((i+1)*3-1) dropped the last one.
            torchvision.utils.save_image(sext[:, i * 3:(i + 1) * 3, :, :], osp.join(base_path, "%s_%s.png" % (name, it)))
2020-10-06 01:35:28 +00:00
# This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss):
    """TecoGAN ping-pong loss: compares each forward-pass frame with its backward-pass twin.

    ``state[opt['fake']]`` is expected to be the list produced by
    RecurrentImageGeneratorSequenceInjector: outputs for frames 0..n-1 followed by
    n-2..0 (length 2n-1, with the last forward frame not repeated).
    """

    def __init__(self, opt, env):
        super(PingPongLoss, self).__init__(opt, env)
        self.opt = opt
        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])

    def forward(self, _, state):
        fake = state[self.opt['fake']]
        l_total = 0
        # The middle element is unpaired, so there are (len - 1) // 2 pairs.
        for i in range((len(fake) - 1) // 2):
            early = fake[i]
            # The backward output for the same frame sits at len-1-i == -(i+1). The previous
            # fake[-i] compared fake[0] with itself and misaligned every other pair by one frame.
            late = fake[-(i + 1)]
            l_total += self.criterion(early, late)
        if self.env['step'] % 100 == 0:
            self.produce_teco_visual_debugs(fake)
        return l_total

    def produce_teco_visual_debugs(self, imglist):
        """Save every image of ``imglist`` (which must be a list) as ``<index>.png``."""
        base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        assert isinstance(imglist, list)
        for i, img in enumerate(imglist):
            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i,)))