2020-07-18 13:23:26 +00:00
import models . archs . SwitchedResidualGenerator_arch as srg
import torch
import torch . nn as nn
from switched_conv_util import save_attention_to_image
from switched_conv import compute_attention_specificity
2020-09-08 14:17:27 +00:00
from models . archs . arch_util import ConvGnLelu , ExpansionBlock , MultiConvBlock
2020-07-18 13:23:26 +00:00
import functools
import torch . nn . functional as F
# Some notes about this new architecture:
# 1) Discriminator is going to need to get update_for_step() called.
# 2) Not sure if pixgan part of discriminator is going to work properly, make sure to test at multiple add levels.
# 3) Also not sure if growth modules will be properly saved/trained, be sure to test this.
# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
class GrowingSRGBase ( nn . Module ) :
2020-07-18 20:18:48 +00:00
def __init__ ( self , progressive_step_schedule , switch_reductions , growth_fade_in_steps , switch_filters , switch_processing_layers , trans_counts ,
2020-07-18 13:23:26 +00:00
trans_layers , transformation_filters , initial_temp = 20 , final_temperature_step = 50000 , upsample_factor = 1 ,
add_scalable_noise_to_transforms = False , start_step = 0 ) :
super ( GrowingSRGBase , self ) . __init__ ( )
switches = [ ]
self . initial_conv = ConvGnLelu ( 3 , transformation_filters , norm = False , activation = False , bias = True )
self . upconv1 = ConvGnLelu ( transformation_filters , transformation_filters , norm = False , bias = True )
self . upconv2 = ConvGnLelu ( transformation_filters , transformation_filters , norm = False , bias = True )
self . hr_conv = ConvGnLelu ( transformation_filters , transformation_filters , norm = False , bias = True )
self . final_conv = ConvGnLelu ( transformation_filters , 3 , norm = False , activation = False , bias = True )
self . switch_filters = switch_filters
self . switch_processing_layers = switch_processing_layers
self . trans_layers = trans_layers
self . transformation_filters = transformation_filters
2020-07-18 20:18:48 +00:00
self . progressive_schedule = progressive_step_schedule
self . switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
2020-07-18 13:23:26 +00:00
self . growth_fade_in_per_step = 1 / growth_fade_in_steps
self . transformation_counts = trans_counts
self . init_temperature = initial_temp
self . final_temperature_step = final_temperature_step
self . attentions = None
self . upsample_factor = upsample_factor
self . add_noise_to_transform = add_scalable_noise_to_transforms
2020-07-18 20:18:48 +00:00
self . start_step = start_step
self . latest_step = start_step
self . fades = [ ]
2020-07-22 17:40:42 +00:00
self . counter = 0
2020-07-18 13:23:26 +00:00
assert self . upsample_factor == 2 or self . upsample_factor == 4
2020-07-18 20:18:48 +00:00
switches = [ ]
for i , ( step , reductions ) in enumerate ( zip ( progressive_step_schedule , switch_reductions ) ) :
multiplx_fn = functools . partial ( srg . ConvBasisMultiplexer , self . transformation_filters , self . switch_filters ,
2020-07-18 13:23:26 +00:00
reductions , self . switch_processing_layers , self . transformation_counts )
2020-07-18 20:18:48 +00:00
pretransform_fn = functools . partial ( ConvGnLelu , self . transformation_filters , self . transformation_filters , norm = False ,
bias = False , weight_init_factor = .1 )
transform_fn = functools . partial ( srg . MultiConvBlock , self . transformation_filters , int ( self . transformation_filters * 1.5 ) ,
self . transformation_filters , kernel_size = 3 , depth = self . trans_layers ,
weight_init_factor = .1 )
switches . append ( srg . ConfigurableSwitchComputer ( self . transformation_filters , multiplx_fn ,
pre_transform_block = pretransform_fn ,
transform_block = transform_fn ,
transform_count = self . transformation_counts , init_temp = self . init_temperature ,
add_scalable_noise_to_transforms = self . add_noise_to_transform ,
attention_norm = False ) )
self . progressive_switches = nn . ModuleList ( switches )
def get_param_groups ( self ) :
param_groups = [ ]
base_param_group = [ ]
for k , v in self . named_parameters ( ) :
if " progressive_switches " not in k and v . requires_grad :
base_param_group . append ( v )
param_groups . append ( { ' params ' : base_param_group } )
for i , sw in enumerate ( self . progressive_switches ) :
sw_param_group = [ ]
for k , v in sw . named_parameters ( ) :
if v . requires_grad :
sw_param_group . append ( v )
param_groups . append ( { ' params ' : sw_param_group } )
return param_groups
2020-07-22 17:39:45 +00:00
# This is a hacky way of modifying the underlying model while training. Since changing the model means changing
# the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
# switch to the end of the chain with depth=3 and an online time set at the end fo the function.
def update_model ( self , opt , sched ) :
multiplx_fn = functools . partial ( srg . ConvBasisMultiplexer , self . transformation_filters , self . switch_filters ,
3 , self . switch_processing_layers , self . transformation_counts )
pretransform_fn = functools . partial ( ConvGnLelu , self . transformation_filters , self . transformation_filters , norm = False ,
bias = False , weight_init_factor = .1 )
transform_fn = functools . partial ( srg . MultiConvBlock , self . transformation_filters , int ( self . transformation_filters * 1.5 ) ,
self . transformation_filters , kernel_size = 3 , depth = self . trans_layers ,
weight_init_factor = .1 )
new_sw = srg . ConfigurableSwitchComputer ( self . transformation_filters , multiplx_fn ,
pre_transform_block = pretransform_fn ,
transform_block = transform_fn ,
transform_count = self . transformation_counts , init_temp = self . init_temperature ,
add_scalable_noise_to_transforms = self . add_noise_to_transform ,
attention_norm = False ) . to ( ' cuda ' )
self . progressive_switches . append ( new_sw )
new_sw_param_group = [ ]
for k , v in new_sw . named_parameters ( ) :
if v . requires_grad :
new_sw_param_group . append ( v )
opt . add_param_group ( { ' params ' : new_sw_param_group } )
self . progressive_schedule . append ( 150000 )
sched . group_starts . append ( 150000 )
2020-07-18 20:18:48 +00:00
def get_progressive_starts ( self ) :
# The base param group starts at step 0, the rest are defined via progressive_switches.
return [ 0 ] + self . progressive_schedule
2020-07-18 13:23:26 +00:00
2020-07-22 17:40:42 +00:00
# This method turns requires_grad on and off for different switches, allowing very large models to be trained while
# using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
# <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
# Switches that are younger than 40000 steps are not eligible to be turned off.
def do_switched_grad ( self , groups = 1 ) :
# If requires_grad is already disabled, don't bother.
if not self . initial_conv . conv . weight . requires_grad or groups == 1 :
return
self . counter = ( self . counter + 1 ) % groups
enabled = [ ]
for i , sw in enumerate ( self . progressive_switches ) :
if self . latest_step - self . progressive_schedule [ i ] > 40000 and i % groups != self . counter :
for p in sw . parameters ( ) :
p . requires_grad = False
else :
enabled . append ( i )
for p in sw . parameters ( ) :
p . requires_grad = True
2020-07-18 13:23:26 +00:00
def forward ( self , x ) :
2020-07-22 17:40:42 +00:00
self . do_switched_grad ( 2 )
2020-07-18 13:23:26 +00:00
x = self . initial_conv ( x )
self . attentions = [ ]
2020-07-18 20:18:48 +00:00
self . fades = [ ]
self . enabled_switches = 0
for i , sw in enumerate ( self . progressive_switches ) :
fade_in = 1 if self . progressive_schedule [ i ] == 0 else 0
2020-07-18 13:23:26 +00:00
if self . latest_step > 0 and self . progressive_schedule [ i ] != 0 :
switch_age = self . latest_step - self . progressive_schedule [ i ]
fade_in = min ( 1 , switch_age * self . growth_fade_in_per_step )
2020-07-18 20:18:48 +00:00
if fade_in > 0 :
self . enabled_switches + = 1
x , att = sw . forward ( x , True , fixed_scale = fade_in )
self . attentions . append ( att )
self . fades . append ( fade_in )
2020-07-18 13:23:26 +00:00
x = self . upconv1 ( F . interpolate ( x , scale_factor = 2 , mode = " nearest " ) )
if self . upsample_factor > 2 :
x = F . interpolate ( x , scale_factor = 2 , mode = " nearest " )
x = self . upconv2 ( x )
x = self . final_conv ( self . hr_conv ( x ) )
return x , x
def update_for_step ( self , step , experiments_path = ' . ' ) :
2020-07-18 20:18:48 +00:00
self . latest_step = step + self . start_step
2020-07-18 13:23:26 +00:00
# Set the temperature of the switches, per-layer.
2020-07-18 20:18:48 +00:00
for i , ( first_step , sw ) in enumerate ( zip ( self . progressive_schedule , self . progressive_switches ) ) :
2020-07-18 13:23:26 +00:00
temp_loss_per_step = ( self . init_temperature - 1 ) / self . final_temperature_step
2020-07-18 20:18:48 +00:00
sw . set_temperature ( min ( self . init_temperature ,
max ( self . init_temperature - temp_loss_per_step * ( step - first_step ) , 1 ) ) )
2020-07-18 13:23:26 +00:00
# Save attention images.
if self . attentions is not None and step % 50 == 0 :
[ save_attention_to_image ( experiments_path , self . attentions [ i ] , self . transformation_counts , step , " a %i " % ( i + 1 , ) , l_mult = 10 ) for i in range ( len ( self . attentions ) ) ]
def get_debug_values ( self , step ) :
mean_hists = [ compute_attention_specificity ( att , 2 ) for att in self . attentions ]
means = [ i [ 0 ] for i in mean_hists ]
hists = [ i [ 1 ] . clone ( ) . detach ( ) . cpu ( ) . flatten ( ) for i in mean_hists ]
val = { }
for i in range ( len ( means ) ) :
val [ " switch_ %i _specificity " % ( i , ) ] = means [ i ]
val [ " switch_ %i _histogram " % ( i , ) ] = hists [ i ]
2020-07-18 20:18:48 +00:00
val [ " switch_ %i _temperature " % ( i , ) ] = self . progressive_switches [ i ] . switch . temperature
for i , f in enumerate ( self . fades ) :
val [ " switch_ %i _fade " % ( i , ) ] = f
val [ " enabled_switches " ] = self . enabled_switches
2020-07-18 13:23:26 +00:00
return val
class DiscriminatorDownsample ( nn . Module ) :
def __init__ ( self , base_filters , end_filters ) :
self . conv0 = ConvGnLelu ( base_filters , end_filters , kernel_size = 3 , bias = False )
self . conv1 = ConvGnLelu ( end_filters , end_filters , kernel_size = 3 , stride = 2 , bias = False )
def forward ( self , x ) :
return self . conv1 ( self . conv0 ( x ) )
class DiscriminatorUpsample ( nn . Module ) :
def __init__ ( self , base_filters , end_filters ) :
self . up = ExpansionBlock ( base_filters , end_filters , block = ConvGnLelu )
self . proc = ConvGnLelu ( end_filters , end_filters , bias = False )
self . collapse = ConvGnLelu ( end_filters , 1 , bias = True , norm = False , activation = False )
def forward ( self , x , ff ) :
x = self . up1 ( x , ff )
return x , self . collapse1 ( self . proc1 ( x ) )
class GrowingUnetDiscBase ( nn . Module ) :
def __init__ ( self , nf , growing_schedule , growth_fade_in_steps , start_step = 0 ) :
super ( GrowingUnetDiscBase , self ) . __init__ ( )
# [64, 128, 128]
self . conv0_0 = ConvGnLelu ( 3 , nf , kernel_size = 3 , bias = True , activation = False )
self . conv0_1 = ConvGnLelu ( nf , nf , kernel_size = 3 , stride = 2 , bias = False )
# [64, 64, 64]
self . conv1_0 = ConvGnLelu ( nf , nf * 2 , kernel_size = 3 , bias = False )
self . conv1_1 = ConvGnLelu ( nf * 2 , nf * 2 , kernel_size = 3 , stride = 2 , bias = False )
self . down_base = DiscriminatorDownsample ( nf * 2 , nf * 4 )
self . up_base = DiscriminatorUpsample ( nf * 4 , nf * 2 )
self . progressive_schedule = growing_schedule
self . growth_fade_in_per_step = 1 / growth_fade_in_steps
self . pnf = nf * 4
self . downsamples = nn . ModuleList ( [ ] )
self . upsamples = nn . ModuleList ( [ ] )
for i , step in enumerate ( growing_schedule ) :
if step > = start_step :
self . add_layer ( i + 1 )
def add_layer ( self ) :
self . downsamples . append ( DiscriminatorDownsample ( self . pnf , self . pnf ) )
self . upsamples . append ( DiscriminatorUpsample ( self . pnf , self . pnf ) )
def update_for_step ( self , step ) :
self . latest_step = step
# Add any new layers as spelled out by the schedule.
if step != 0 :
for i , s in enumerate ( self . progressive_schedule ) :
if s == step :
self . add_layer ( i + 1 )
def forward ( self , x , output_feature_vector = False ) :
x = self . conv0_0 ( x )
x = self . conv0_1 ( x )
x = self . conv1_0 ( x )
x = self . conv1_1 ( x )
base_fea = self . down_base ( x )
x = base_fea
skips = [ ]
for down in self . downsamples :
x = down ( x )
skips . append ( x )
losses = [ ]
for i , up in enumerate ( self . upsamples ) :
j = i + 1
x , loss = up ( x , skips [ - j ] )
losses . append ( loss )
# This variant averages the outputs of the U-net across the upsamples, weighting the contribution
# to the average less for newly growing levels.
_ , base_loss = self . up_base ( x , base_fea )
res = base_loss . shape [ 2 : ]
mean_weight = 1
for i , l in enumerate ( losses ) :
fade_in = 1
if self . latest_step > 0 and self . progressive_schedule [ i ] != 0 :
disc_age = self . latest_step - self . progressive_schedule [ i ]
fade_in = min ( 1 , disc_age * self . growth_fade_in_per_step )
mean_weight + = fade_in
2020-08-04 17:28:52 +00:00
base_loss + = F . interpolate ( l , size = res , mode = " bilinear " , align_corners = False ) * fade_in
2020-07-18 13:23:26 +00:00
base_loss / = mean_weight
return base_loss . view ( - 1 , 1 )
def pixgan_parameters ( self ) :
return 1 , 4