2020-10-22 20:39:19 +00:00
from torch . cuda . amp import autocast
2020-12-18 16:24:31 +00:00
from models . stylegan . stylegan2_lucidrains import gradient_penalty
2020-12-18 16:18:34 +00:00
from trainer . losses import ConfigurableLoss , GANLoss , extract_params_from_state , get_basic_criterion_for_name
2020-12-18 16:24:31 +00:00
from models . flownet2 . networks import Resample2d
2020-12-30 03:58:02 +00:00
from trainer . inject import Injector
2020-10-06 01:35:28 +00:00
import torch
2020-10-07 15:03:30 +00:00
import torch . nn . functional as F
2020-10-07 02:35:39 +00:00
import os
import os . path as osp
import torchvision
2020-12-18 16:10:44 +00:00
2020-10-06 01:35:28 +00:00
def create_teco_loss(opt, env):
    """Instantiate the TecoGAN-family loss selected by ``opt['type']``.

    Returns:
        A ``TecoGanLoss`` for ``'teco_gan'``, a ``PingPongLoss`` for ``'teco_pingpong'``,
        or ``None`` when the type is not handled here (so callers can try other factories).
    """
    loss_type = opt['type']
    if loss_type == 'teco_gan':
        return TecoGanLoss(opt, env)
    if loss_type == "teco_pingpong":
        return PingPongLoss(opt, env)
    return None
2020-10-07 15:03:30 +00:00
def create_teco_injector(opt, env):
    """Instantiate the TecoGAN-family injector selected by ``opt['type']``.

    Returns:
        A ``RecurrentImageGeneratorSequenceInjector`` or ``FlowAdjustment`` for the matching
        type string, or ``None`` when the type is not handled here.
    """
    injector_type = opt['type']
    if injector_type == 'teco_recurrent_generated_sequence_injector':
        return RecurrentImageGeneratorSequenceInjector(opt, env)
    if injector_type == 'teco_flow_adjustment':
        return FlowAdjustment(opt, env)
    return None
2020-10-06 01:35:28 +00:00
2020-10-07 15:03:30 +00:00
def extract_inputs_index(inputs, i):
    """Slice time step ``i`` out of every tensor in ``inputs``.

    Tensors are indexed as ``t[:, i]`` (i.e. along dim 1); any non-tensor entry is passed
    through unchanged. Returns a new list in the same order as ``inputs``.
    """
    return [item[:, i] if isinstance(item, torch.Tensor) else item for item in inputs]
2020-10-06 01:35:28 +00:00
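# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out].
# Images are fed in sequentially forward and then (when 'do_backwards' is set) back, resulting in
# len([out]) = 2*len([in]) - 1 (the last element is not repeated).
# The flow-network warping of the recurrent input is meant to run under torch.no_grad(); the generator itself runs
# with gradients enabled.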
class RecurrentImageGeneratorSequenceInjector(Injector):
    """Run a recurrent generator over a frame sequence, forward and (optionally) back again.

    For each visited frame ``i`` the generator receives the inputs sliced at time step ``i``
    (see ``extract_inputs_index``), with the slot at ``recurrent_index`` replaced by the previous
    step's recurrent output warped onto frame ``i`` by the flow network. Frames are visited
    0..f-1 and then, when ``do_backwards`` is set, f-2..0, so each output sequence has 2*f-1
    entries (f without the backward pass).

    The flow warping runs under ``torch.no_grad()`` with autocast disabled; the generator call
    runs with gradients enabled and autocast controlled by ``env['opt']['fp16']``.
    """

    def __init__(self, opt, env):
        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
        self.flow = opt['flow_network']
        self.input_lq_index = opt.get('input_lq_index', 0)
        self.recurrent_index = opt['recurrent_index']
        self.output_hq_index = opt.get('output_hq_index', 0)
        self.output_recurrent_index = opt.get('output_recurrent_index', self.output_hq_index)
        self.scale = opt['scale']
        self.resample = Resample2d()
        # Optional state key providing the LQ frames fed to the flow network; defaults to the LQ generator input.
        self.flow_key = opt.get('flow_input_key', None)
        # Inputs used for the first teco iteration only; every later iteration uses 'in'.
        self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in']
        self.do_backwards = opt.get('do_backwards', True)
        # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in.
        # When False, zeros are fed into the recurrent index.
        self.hq_recurrent = opt.get('hq_recurrent', False)
        self.hq_batched_output_key = opt.get('hq_batched_key', None)

    def forward(self, state):
        """Generate the full (ping-pong) sequence and return it keyed by ``self.output``.

        Returns a dict with, for every output key ``k``: ``k`` -> per-step outputs stacked on
        dim 1, and ``k + "_batched"`` -> the first ``s`` step outputs concatenated on dim 0, where
        ``s`` is ``state['hq'].shape[1]``. When ``hq_batched_key`` is configured it also holds
        ``state['hq']`` flattened in the same time-major order.
        """
        gen = self.env['generators'][self.opt['generator']]
        flow = self.env['generators'][self.flow]
        first_inputs = extract_params_from_state(self.first_inputs, state)
        inputs = extract_params_from_state(self.input, state)
        if not isinstance(inputs, list):
            inputs = [inputs]

        if not isinstance(self.output, list):
            self.output = [self.output]
        results = {out_key: [] for out_key in self.output}

        num_frames = inputs[self.input_lq_index].shape[1]
        # Go forward through the sequence, then back again skipping the last frame (its output is already the
        # current recurrent state).
        frame_order = list(range(num_frames))
        if self.do_backwards:
            frame_order.extend(reversed(range(num_frames - 1)))

        hq_recurrent = None
        recurrent_input = None
        for step_num, i in enumerate(frame_order):
            if step_num == 0:
                step_input = extract_inputs_index(first_inputs, i)
                if self.hq_recurrent:
                    recurrent_input = step_input[self.recurrent_index]
                else:
                    recurrent_input = torch.zeros_like(step_input[self.recurrent_index])
            else:
                step_input = extract_inputs_index(inputs, i)
                recurrent_input = self._warp_recurrent(flow, state, step_input, i, hq_recurrent, recurrent_input)
            step_input[self.recurrent_index] = recurrent_input

            # Only debug this if we're dealing with images.
            if self.env['step'] % 50 == 0 and step_input[self.input_lq_index].shape[1] == 3:
                self.produce_teco_visual_debugs(step_input[self.input_lq_index],
                                                step_input[self.recurrent_index], step_num)

            with autocast(enabled=self.env['opt']['fp16']):
                gen_out = gen(*step_input)
            if isinstance(gen_out, torch.Tensor):
                gen_out = [gen_out]
            for out_idx, out_key in enumerate(self.output):
                results[out_key].append(gen_out[out_idx])
            hq_recurrent = gen_out[self.output_hq_index]
            recurrent_input = gen_out[self.output_recurrent_index]

        final_results = {}
        # Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
        b, s, c, h, w = state['hq'].shape
        if self.hq_batched_output_key is not None:
            # (b, s, ...) -> (s, b, ...) -> (s*b, ...): time-major, matching the torch.cat order below.
            final_results[self.hq_batched_output_key] = state['hq'].clone().permute(1, 0, 2, 3, 4).reshape(b * s, c, h, w)
        for k, v in results.items():
            final_results[k] = torch.stack(v, dim=1)
            # Only include the original sequence - this output is generally used to compare against HQ.
            final_results[k + "_batched"] = torch.cat(v[:s], dim=0)
        return final_results

    def _warp_recurrent(self, flow, state, step_input, i, hq_recurrent, recurrent_input):
        """Warp the previous recurrent output onto frame ``i`` using the flow network.

        Runs under ``torch.no_grad()`` with autocast disabled (the two context managers are
        entered together; ``with a and b`` would only enter ``b``), so the result carries no
        gradient. Inputs to the flow network and the resampler are cast to float.
        """
        with torch.no_grad(), autocast(enabled=False):
            if self.flow_key is not None:
                flow_input = state[self.flow_key][:, i]
            else:
                flow_input = step_input[self.input_lq_index]
            # Bring the previous HQ output down to the flow input's resolution before pairing them.
            reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1 / self.scale, mode='bicubic')
            flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
            flowfield = flow(flow_input)
            # If the recurrent tensor is at a different resolution than the flow input, resize the flow field.
            # NOTE(review): interpolation does not rescale flow magnitudes; confirm whether the flow network's
            # output should also be multiplied by self.scale here.
            if recurrent_input.shape[-1] != flow_input.shape[-1]:
                flowfield = F.interpolate(flowfield, scale_factor=self.scale, mode='bicubic')
            return self.resample(recurrent_input.float(), flowfield)

    def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
        """Save the generator input and recurrent input of step ``it`` as images (rank 0 only)."""
        if self.env['rank'] > 0:
            return
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_geninput", str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
        torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
2020-10-07 18:41:17 +00:00
2020-10-06 01:35:28 +00:00
2020-10-10 01:21:43 +00:00
class FlowAdjustment(Injector):
    """Warp ``state[opt['flowed']]`` toward ``state[opt['flow_target']]`` with a flow network.

    The flowed tensor is first resized to the target's spatial size so the pair can be stacked
    and fed to the flow network. The resulting flow field is then resized back to the flowed
    tensor's spatial size and used to resample the original (un-resized) flowed tensor. All of
    this runs with autocast disabled. The result is returned under ``self.output``.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.resample = Resample2d()
        self.flow = opt['flow_network']
        self.flow_target = opt['flow_target']
        self.flowed = opt['flowed']

    def forward(self, state):
        with autocast(enabled=False):
            flow_net = self.env['generators'][self.flow]
            target = state[self.flow_target]
            source = state[self.flowed]
            source_at_target_res = F.interpolate(source, size=target.shape[2:], mode='bicubic')
            pair = torch.stack([target, source_at_target_res], dim=2).float()
            flow_at_source_res = F.interpolate(flow_net(pair), size=source.shape[2:], mode='bicubic')
            return {self.output: self.resample(source, flow_at_source_res)}
2020-10-10 01:21:43 +00:00
2020-10-28 04:40:15 +00:00
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
    """Build one TecoGAN discriminator input from frames ``index .. index+2`` of ``input_list``.

    The frame triplet is concatenated with a flow-aligned copy of itself (first and last frames
    warped toward the middle frame via ``flow_gen`` + ``resampler``). The resulting 6 images are
    folded into the channel dimension and the spatial border is optionally cropped.

    Args:
        input_list: frame tensor unpacked as (b, f, c, h, w); the 18-channel view requires c == 3.
        lr_imgs: unused; kept for interface compatibility.
        scale: unused; kept for interface compatibility.
        index: first frame of the triplet.
        flow_gen: called with two frames stacked on dim 2; its output is passed to ``resampler``.
        resampler: callable(image, flow) -> warped image.
        margin: pixels cropped from every spatial border. 0 (or less) disables cropping;
            previously 0 produced ``x[0:-0]``, i.e. an empty tensor.

    Returns:
        Tensor of shape (b, 18, h - 2*margin, w - 2*margin), or (b, 18, h, w) without cropping.
    """
    # NOTE(review): flow is estimated from input_list itself (real or generated frames); lr_imgs is not used.
    with autocast(enabled=False):
        triplet = input_list[:, index:index + 3].float()
        first_flow = flow_gen(torch.stack([triplet[:, 1], triplet[:, 0]], dim=2))
        last_flow = flow_gen(torch.stack([triplet[:, 1], triplet[:, 2]], dim=2))
        flow_triplet = torch.stack([resampler(triplet[:, 0], first_flow),
                                    triplet[:, 1],
                                    resampler(triplet[:, 2], last_flow)], dim=1)
        combined = torch.cat([triplet, flow_triplet], dim=1)
        b, f, c, h, w = combined.shape
        combined = combined.view(b, 3 * 6, h, w)  # 3*6 is essentially an assertion that f*c == 18.
        # Apply margin (skipped for margin <= 0, where negative-index slicing would empty the tensor).
        if margin > 0:
            combined = combined[:, :, margin:-margin, margin:-margin]
        return combined
def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
    """Build ``total`` TecoGAN discriminator sextuplets at once, stacked along the batch dim.

    Sextuplet ``i`` covers frames ``i, i+1, i+2``: the three frames followed by frame ``i``
    warped with the backward flow, the middle frame, and frame ``i+2`` warped with the forward
    flow. All flow fields are computed with two batched ``flow_gen`` calls for efficiency.

    Args:
        input_list: frame tensor indexed as (b, f, c, h, w) with f >= total + 2 and c == 3.
        lr_imgs: unused; kept for interface compatibility.
        scale: unused; kept for interface compatibility.
        total: number of sextuplets to build.
        flow_gen: called with two frames stacked on dim 2; its output is passed to ``resampler``.
        resampler: callable(image, flow) -> warped image.
        margin: pixels cropped from every spatial border. 0 (or less) disables cropping;
            previously 0 produced ``x[0:-0]``, i.e. an empty tensor.

    Returns:
        Tensor of shape (total*b, 18, h - 2*margin, w - 2*margin); rows ``i*b:(i+1)*b`` hold
        sextuplet ``i``.
    """
    with autocast(enabled=False):
        input_list = input_list.float()
        batch_sz = input_list.shape[0]
        # Combine everything and feed it into the flow network at once for better efficiency.
        middles = range(1, total + 1)
        flux_doubles_forward = [torch.stack([input_list[:, m], input_list[:, m + 1]], dim=2) for m in middles]
        flux_doubles_backward = [torch.stack([input_list[:, m], input_list[:, m - 1]], dim=2) for m in middles]
        flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
        flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
        sexts = []
        for i in range(total):
            rows = slice(batch_sz * i, batch_sz * (i + 1))
            prev_frame = input_list[:, i]
            mid = input_list[:, i + 1]
            next_frame = input_list[:, i + 2]
            sext = torch.stack([prev_frame, mid, next_frame,
                                resampler(prev_frame, flows_backward[rows]),
                                mid,
                                resampler(next_frame, flows_forward[rows])], dim=1)
            b, f, c, h, w = sext.shape
            sext = sext.view(b, 3 * 6, h, w)  # f*c = 6*3
            # Apply margin (skipped for margin <= 0, where negative-index slicing would empty the tensor).
            if margin > 0:
                sext = sext[:, :, margin:-margin, margin:-margin]
            sexts.append(sext)
    return torch.cat(sexts, dim=0)
2020-10-06 01:35:28 +00:00
# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contract for 'real' and 'fake' inputs:
#  'real' - Must be a sequence of at least 3 images drawn from the dataset.
#  'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
#
# This loss does the following:
# 1) Picks an image triplet, starting with the first 3 elements in 'real' and 'fake'.
# 2) Uses the image flow generator (specified with 'image_flow_generator') to compute flow fields that align the first
#    and last images of the triplet to the middle image.
# 3) Warps the first and last images according to those flow fields.
# 4) Concatenates the three original images, the two warped images and the middle image along the channel dimension,
#    for both real and fake, producing a bx18xhxw tensor (with the configured 'margin' cropped from each border).
# 5) Feeds the real and fake tensors into the discriminator and computes a loss, which is accumulated.
# 6) Repeats from (1), sliding forward by one frame, until all triplets from the real sequence have been used.
#    (With 'fast_forward', all triplets are batched into a single discriminator pass instead.)
2020-10-07 15:03:30 +00:00
class TecoGanLoss(ConfigurableLoss):
    """TecoGAN temporal discriminator loss over flow-aligned frame sextuplets.

    Reads ``state[opt['real']]``, ``state[opt['fake']]`` and ``state[opt['lr_inputs']]``, builds
    one 18-channel input per consecutive frame triplet, and scores real vs. fake with
    ``env['discriminators'][opt['discriminator']]`` using ``GANLoss``.
    """

    def __init__(self, opt, env):
        super(TecoGanLoss, self).__init__(opt, env)
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
        # TecoGAN parameters
        self.scale = opt['scale']
        self.lr_inputs = opt['lr_inputs']
        self.image_flow_generator = opt['image_flow_generator']
        self.resampler = Resample2d()
        self.for_generator = opt['for_generator']
        # Losses not exceeding this threshold are dropped (see lowmem_forward / fast_forward).
        self.min_loss = opt.get('min_loss', 0)
        # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin
        # removed, to get rid of artifacts resulting from flow errors.
        self.margin = opt['margin']
        self.ff = opt.get('fast_forward', False)
        self.noise = opt.get('noise', 0)
        self.gradient_penalty = opt.get('gradient_penalty', False)

    def forward(self, _, state):
        if self.ff:
            return self.fast_forward(state)
        return self.lowmem_forward(state)

    def lowmem_forward(self, state):
        """Compute the loss one triplet at a time: lower memory overhead, but slower.

        A step's loss is accumulated only when it exceeds ``min_loss``; when enabled, the
        gradient penalty is added together with that loss, matching ``fast_forward``.
        (It used to sit in an ``elif`` and so was applied only when the loss was dropped.)
        """
        flow_gen = self.env['generators'][self.image_flow_generator]
        real = state[self.opt['real']]
        fake = state[self.opt['fake']]
        sequence_len = real.shape[1]
        lr = state[self.opt['lr_inputs']]
        l_total = 0
        for i in range(sequence_len - 2):
            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler,
                                                            self.margin)
            if self.gradient_penalty:
                # The gradient penalty differentiates d_real with respect to this input.
                real_sext.requires_grad_()
            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler,
                                                            self.margin)
            l_step, d_real = self.compute_loss(real_sext, fake_sext)
            if l_step > self.min_loss:
                l_total = l_total + l_step
                if self.gradient_penalty:
                    l_total = l_total + gradient_penalty(real_sext, d_real)
        return l_total

    def fast_forward(self, state):
        """Compute the loss by stacking every sextuplet into the batch dim for one discriminator pass.

        Higher memory but faster. The loss (plus gradient penalty, when enabled) is zeroed when
        it falls below ``min_loss``.
        """
        flow_gen = self.env['generators'][self.image_flow_generator]
        real = state[self.opt['real']]
        fake = state[self.opt['fake']]
        sequence_len = real.shape[1]
        lr = state[self.opt['lr_inputs']]
        # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
        combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
                                                                 self.resampler, self.margin)
        if self.gradient_penalty:
            combined_real_sext.requires_grad_()
        combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
                                                                 self.resampler, self.margin)
        l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
        if l_total < self.min_loss:
            l_total = 0
        elif self.gradient_penalty:
            gp = gradient_penalty(combined_real_sext, d_real)
            l_total = l_total + gp
        return l_total

    def compute_loss(self, real_sext, fake_sext):
        """Score real and fake sextuplets with the discriminator.

        Returns:
            ``(l_step, d_real)`` — the GAN loss for this batch and the raw discriminator output on
            the real input (needed by the gradient penalty).

        Raises:
            NotImplementedError: for a ``gan_type`` other than 'gan', 'pixgan' or 'ragan'.
        """
        fp16 = self.env['opt']['fp16']
        net = self.env['discriminators'][self.opt['discriminator']]
        if self.noise != 0:
            # NOTE(review): torch.rand_like draws uniform [0, 1) noise, which has a non-zero mean; confirm this
            # (rather than randn_like) is intended.
            real_sext = real_sext + torch.rand_like(real_sext) * self.noise
            fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
        with autocast(enabled=fp16):
            d_fake = net(fake_sext)
            d_real = net(real_sext)
        self.metrics.append(("d_fake", torch.mean(d_fake)))
        self.metrics.append(("d_real", torch.mean(d_real)))

        if self.for_generator and self.env['step'] % 50 == 0:
            self.produce_teco_visual_debugs(fake_sext, 'fake', 0)
            self.produce_teco_visual_debugs(real_sext, 'real', 0)

        if self.opt['gan_type'] in ['gan', 'pixgan']:
            l_fake = self.criterion(d_fake, self.for_generator)
            if not self.for_generator:
                l_real = self.criterion(d_real, True)
            else:
                l_real = 0
            l_step = l_fake + l_real
        elif self.opt['gan_type'] == 'ragan':
            d_fake_diff = d_fake - torch.mean(d_real)
            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
            l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
                      self.criterion(d_fake_diff, self.for_generator))
        else:
            raise NotImplementedError
        return l_step, d_real

    def produce_teco_visual_debugs(self, sext, lbl, it):
        """Save each 3-channel slice of ``sext`` as an image under a per-step/label folder (rank 0 only)."""
        if self.env['rank'] > 0:
            return
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
        os.makedirs(base_path, exist_ok=True)
        lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
        for i in range(6):
            torchvision.utils.save_image(sext[:, i * 3:(i + 1) * 3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
2020-10-07 02:35:39 +00:00
2020-10-06 01:35:28 +00:00
# This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss):
    """Ping-pong consistency loss over a forward-then-backward generated sequence.

    With ``n = state[opt['fake']].shape[1]`` frames, frame ``k`` is compared against frame
    ``-(k + 1)`` for ``k`` in ``range((n - 1) // 2)`` using the criterion named by
    ``opt['criterion']``; the per-pair losses are summed. No 'real' input is used.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.opt = opt
        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])

    def forward(self, _, state):
        fake = state[self.opt['fake']]
        pair_count = (fake.shape[1] - 1) // 2
        l_total = 0
        for k in range(pair_count):
            l_total += self.criterion(fake[:, k], fake[:, -(k + 1)])

        if self.env['step'] % 50 == 0:
            self.produce_teco_visual_debugs(fake)
        return l_total

    def produce_teco_visual_debugs(self, imglist):
        """Save every frame of ``imglist`` (indexed along dim 1) as an image (rank 0 only)."""
        if self.env['rank'] > 0:
            return
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        for frame_idx in range(imglist.shape[1]):
            frame = imglist[:, frame_idx]
            torchvision.utils.save_image(frame.float(), osp.join(base_path, "%s.png" % (frame_idx,)))

    def produce_teco_visual_debugs2(self, imga, imgb, i):
        """Save one compared frame pair as ``<i>_a.png`` / ``<i>_b.png`` (rank 0 only)."""
        if self.env['rank'] > 0:
            return
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", "teco_pingpong", str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i,)))
        torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i,)))