2020-12-18 23:04:19 +00:00
import math
import torch . nn as nn
import torch
from models . RRDBNet_arch import RRDB
from models . arch_util import ConvGnLelu
# GleanEncoderBlock (defined below) produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models . glean . stylegan2_latent_bank import Stylegan2LatentBank
from models . stylegan . stylegan2_rosinality import EqualLinear
from utils . util import checkpoint , sequential_checkpoint
class GleanEncoderBlock(nn.Module):
    """One reduction stage of the GLEAN encoder.

    For an input with `nf` filters, the block yields two tensors:
      * a reduced feature map with `nf * 2` filters, produced by a stride-2
        3x3 convolution followed by a second 3x3 convolution, and
      * a "structural latent" (`f`), a 1x1 convolution of the unreduced input.
    """

    def __init__(self, nf):
        super().__init__()
        # 1x1 projection of the incoming features; keeps the filter count at `nf`.
        self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False,
                                                 norm=False, bias=True)
        doubled_nf = nf * 2
        # Strided convolution that doubles the filter count, then a refinement conv.
        strided_conv = ConvGnLelu(nf, doubled_nf, kernel_size=3, stride=2, activation=True,
                                  norm=False, bias=False)
        refine_conv = ConvGnLelu(doubled_nf, doubled_nf, kernel_size=3, activation=True,
                                 norm=False, bias=False)
        self.process = nn.Sequential(strided_conv, refine_conv)

    def forward(self, x):
        """Return `(reduced_features, structural_latent)` computed from `x`."""
        projected = self.structural_latent_conv(x)
        reduced = self.process(x)
        return reduced, projected
# Produces RRDB features, a list of convolutional features (`f` shape=[l][b,c,h,w] l=levels aka f_sub)
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
class GleanEncoder(nn.Module):
    """Encoder half of the GLEAN generator.

    The forward pass returns three things:
      * the RRDB trunk features,
      * a list holding the structural latent emitted by each reduction stage,
      * latent vectors reshaped to `[batch, latent_bank_blocks, -1]`.

    Args:
        nf: Filter count produced by the initial convolution and used by the RRDB trunk.
        nb: Number of RRDB blocks in the trunk.
        reductions: Number of `GleanEncoderBlock` stages; stage `i` receives `nf * 2**i` filters.
        latent_bank_blocks: Number of latent vectors to emit per sample.
        latent_bank_latent_dim: Width of each emitted latent vector.
        input_dim: Side length of the (square) input; used to size the linear layer.
    """

    def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32):
        super().__init__()
        self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True)

        trunk = [RRDB(nf) for _ in range(nb)]
        self.rrdb_blocks = nn.Sequential(*trunk)

        stages = []
        for level in range(reductions):
            stages.append(GleanEncoderBlock(nf * 2 ** level))
        self.reducers = nn.ModuleList(stages)

        # Size of the tensor leaving the last reducer, assuming each stage halves the
        # spatial side and doubles the filters.
        total_shrink = 2 ** reductions
        reducer_output_dim = (input_dim // total_shrink) ** 2
        reducer_output_nf = nf * total_shrink

        self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1,
                                      activation=True, norm=False, bias=True)
        linear_in = reducer_output_dim * reducer_output_nf
        linear_out = latent_bank_latent_dim * latent_bank_blocks
        # NOTE(review): the literal is reproduced exactly as found, including the padding
        # spaces; they look like an extraction artifact -- confirm against EqualLinear.
        self.latent_linear = EqualLinear(linear_in, linear_out, activation=" fused_lrelu ")
        self.latent_bank_blocks = latent_bank_blocks

    def forward(self, x):
        """Encode `x` into `(rrdb_fea, convolutional_features, latents)`."""
        trunk_in = self.initial_conv(x)
        fea = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), trunk_in)
        rrdb_fea = fea

        convolutional_features = []
        for stage in self.reducers:
            fea, structural = checkpoint(stage, fea)
            convolutional_features.append(structural)

        flattened = self.latent_conv(fea).flatten(1, -1)
        batch = fea.shape[0]
        latents = self.latent_linear(flattened).view(batch, self.latent_bank_blocks, -1)
        return rrdb_fea, convolutional_features, latents
# Produces an image by fusing the output features from the latent bank.
class GleanDecoder(nn.Module):
    """Produces an image by fusing RRDB features with latent bank output features.

    Each decoder block upsamples the running feature map by 2x, concatenates the
    matching latent bank feature along the channel axis, and convolves the result
    down to that latent bank level's filter count. A final convolution maps the
    last level to 3 output channels.

    Args:
        nf: Filter count of the incoming RRDB features.
        latent_bank_filters: Filter counts of the latent bank features, in the order
            they are consumed. One decoder block is created per entry. Any sequence
            of ints is accepted; it is copied and never mutated.
    """

    # To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from stylegan2_rosinality.py
    def __init__(self, nf, latent_bank_filters=(512, 256, 128)):
        super().__init__()
        # Work on a private list so the (immutable) default is never shared or mutated.
        latent_bank_filters = list(latent_bank_filters)

        self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True,
                                       weight_init_factor=0.1)

        # Block i receives the previous block's output (nf for the first block)
        # concatenated with the i-th latent bank feature.
        decoder_block_shuffled_dims = [nf] + latent_bank_filters
        self.decoder_blocks = nn.ModuleList([
            ConvGnLelu(carried_nf + bank_nf, bank_nf,
                       kernel_size=3, bias=True, norm=False, activation=True,
                       weight_init_factor=0.1)
            for carried_nf, bank_nf in zip(decoder_block_shuffled_dims, latent_bank_filters)
        ])

        final_dim = latent_bank_filters[-1]
        self.final_decode = ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False,
                                       weight_init_factor=0.1)

    def forward(self, rrdb_fea, latent_bank_fea):
        """Decode an image from `rrdb_fea` and the indexable `latent_bank_fea`.

        `latent_bank_fea[i]` must be concatenable (dim=1) with the 2x-upsampled
        output of the previous stage for every decoder block `i`.
        """
        fea = self.initial_conv(rrdb_fea)
        for i, block in enumerate(self.decoder_blocks):
            # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't align with the way the underlying StyleGAN works.
            fea = nn.functional.interpolate(fea, scale_factor=2, mode="nearest")
            fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
            fea = checkpoint(block, fea)
        return self.final_decode(fea)
class GleanGenerator(nn.Module):
    """Full GLEAN generator: encoder -> StyleGAN2 latent bank -> decoder.

    Args:
        nf: Base filter count shared by the encoder, latent bank adapter and decoder.
        latent_bank_pretrained_weights: Passed through to `Stylegan2LatentBank`.
        latent_bank_max_dim: Passed to `Stylegan2LatentBank` as `max_dim`.
        gen_output_dim: Side length of the generated image; sets the number of latent
            vectors the encoder emits and (with `input_dim`) the number of decoder levels.
        encoder_rrdb_nb: Number of RRDB blocks in the encoder trunk.
        encoder_reductions: Number of reduction stages in the encoder.
        latent_bank_latent_dim: Width of each latent vector.
        input_dim: Required side length of the square input; enforced in `forward`.
    """

    def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
                 encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32):
        super().__init__()
        self.input_dim = input_dim
        latent_blocks = self._log2_int(gen_output_dim)  # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
        self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
                                    latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim)
        decoder_blocks = self._log2_int(gen_output_dim / input_dim)
        latent_bank_filters_out = [512, 256, 128]  # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
        self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
                                               latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions,
                                               decoder_levels=decoder_blocks)
        self.decoder = GleanDecoder(nf, latent_bank_filters_out)

    @staticmethod
    def _log2_int(value):
        """Return the base-2 logarithm of `value`, truncated to an int.

        `math.log2` is used instead of `math.log(value, 2)`: the two-argument form
        divides two rounded logarithms and can land just off an integer, which
        `int()` could then truncate to the wrong level count.
        """
        return int(math.log2(value))

    def forward(self, x):
        """Run the generator on `x`, whose last two dimensions must both equal `input_dim`."""
        assert self.input_dim == x.shape[-1] and self.input_dim == x.shape[-2], (
            f"expected a {self.input_dim}x{self.input_dim} input, got spatial size {tuple(x.shape[-2:])}"
        )
        rrdb_fea, conv_fea, latents = self.encoder(x)
        latent_bank_fea = self.latent_bank(conv_fea, latents)
        return self.decoder(rrdb_fea, latent_bank_fea)