@@ -3,13 +3,11 @@ from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
-import torch.nn.init as init
-from utils.util import checkpoint
-import torch_intermediary as ml
+import dlas.torch_intermediary as ml
+from dlas.utils.util import checkpoint


def exists(val):
@@ -21,14 +19,14 @@ def default(val, d):


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps=1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)
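

# Illustrative sketch (not part of the original file): exercising the helpers
# above on toy tensors.
running_mean = torch.zeros(4)
ema_inplace(running_mean, torch.ones(4), decay=0.99)      # running_mean -> 0.01
probs = laplace_smoothing(torch.tensor([3., 1., 0.]), n_categories=3)
assert abs(probs.sum().item() - 1.0) < 1e-4               # smoothing keeps a valid distribution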
@@ -36,9 +34,9 @@ def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]
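
# Illustrative sketch (not part of the original file): drawing 4 rows from a
# hypothetical (10, 8) bank of vectors; with fewer rows than requested the
# helper falls back to sampling with replacement.
subset = sample_vectors(torch.randn(10, 8), num=4)
assert subset.shape == (4, 8)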
@@ -239,7 +237,8 @@ class AttentionPool2d(nn.Module):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype)  # NC(HW+1)
+        x = x + self.positional_embedding[None,
+                                          :, :x.shape[-1]].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
@@ -296,7 +295,8 @@ class Upsample(nn.Module):
            if dims == 1:
                ksize = 5
                pad = 2
-            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
+            self.conv = conv_nd(dims, self.channels,
+                                self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
@ -346,6 +346,7 @@ class cGLU(nn.Module):
"""
Gated GELU for channel - first architectures .
"""
def __init__ ( self , dim_in , dim_out = None ) :
super ( ) . __init__ ( )
dim_out = dim_in if dim_out is None else dim_out
@@ -395,7 +396,8 @@ class ResBlock(nn.Module):
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+            conv_nd(dims, channels, self.out_channels,
+                    kernel_size, padding=padding),
        )

        self.updown = up or down
@@ -414,7 +416,8 @@ class ResBlock(nn.Module):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
+                conv_nd(dims, self.out_channels, self.out_channels,
+                        kernel_size, padding=padding)
            ),
        )
@@ -425,7 +428,8 @@ class ResBlock(nn.Module):
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 1)

    def forward(self, x):
        """
@@ -466,10 +470,10 @@ def build_local_attention_mask(n, l, fixed_region=0):
    A mask that can be applied to AttentionBlock to achieve local attention.
    """
    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
    o = torch.arange(0, n)
    c = o.unsqueeze(-1).repeat(1, n)
    r = o.unsqueeze(0).repeat(n, 1)
    localized = ((-(r - c).abs()) + l).clamp(0, l - 1) / (l - 1)
    localized[:fixed_region] = 1
    localized[:, :fixed_region] = 1
    mask = localized > 0
@@ -477,7 +481,7 @@ def build_local_attention_mask(n, l, fixed_region=0):


def test_local_attention_mask():
    print(build_local_attention_mask(9, 4, 1))
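

# Illustrative sketch (not part of the original file): the resulting boolean
# mask can be passed to AttentionBlock.forward() via its `mask` argument, which
# broadcasts a 2D mask across the batch and truncates it to the sequence length.
local_mask = build_local_attention_mask(n=9, l=4, fixed_region=1)
assert local_mask.shape == (9, 9) and local_mask.dtype == torch.bool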


class RelativeQKBias(nn.Module):
@@ -487,17 +491,18 @@ class RelativeQKBias(nn.Module):
    If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric.
    """

    def __init__(self, l, max_positions=4000, symmetric=True):
        super().__init__()
        if symmetric:
            self.emb = nn.Parameter(torch.randn(l + 1) * .01)
            o = torch.arange(0, max_positions)
            c = o.unsqueeze(-1).repeat(1, max_positions)
            r = o.unsqueeze(0).repeat(max_positions, 1)
            M = ((-(r - c).abs()) + l).clamp(0, l)
        else:
            self.emb = nn.Parameter(torch.randn(l * 2 + 2) * .01)
            a = torch.arange(0, max_positions)
            c = a.unsqueeze(-1) - a
            m = (c >= -l).logical_and(c <= l)
            M = (l + c + 1) * m
@@ -508,7 +513,7 @@ class RelativeQKBias(nn.Module):
        # return self.emb[self.M[:n, :n]].view(1,n,n)
        # However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
        # So, enter this horrible, equivalent mess:
        return torch.gather(self.emb.unsqueeze(-1).repeat(1, n), 0, self.M[:n, :n]).view(1, n, n)
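
# Illustrative sketch (not part of the original file): the gather expression
# above is equivalent to plain integer indexing, shown here on toy tensors.
emb = torch.randn(5)                     # hypothetical bias table
M = torch.randint(0, 5, (3, 3))          # hypothetical index matrix
direct = emb[M]
gathered = torch.gather(emb.unsqueeze(-1).repeat(1, 3), 0, M)
assert torch.equal(direct, gathered)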


class AttentionBlock(nn.Module):
@@ -550,7 +555,8 @@ class AttentionBlock(nn.Module):
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)
-        self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
+        self.x_proj = nn.Identity() if out_channels == channels else conv_nd(
+            1, channels, out_channels, 1)
        self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))

    def forward(self, x, mask=None, qk_bias=None):
@@ -572,7 +578,7 @@ class AttentionBlock(nn.Module):
        b, c, *spatial = x.shape
        if mask is not None:
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
            if mask.shape[1] != x.shape[-1]:
                mask = mask[:, :x.shape[-1], :x.shape[-1]]
@@ -606,7 +612,8 @@ class QKVAttentionLegacy(nn.Module):
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
+                              length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
@@ -651,7 +658,8 @@ class QKVAttention(nn.Module):
            mask = mask.repeat(self.n_heads, 1, 1)
            weight[mask.logical_not()] = -torch.inf
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        a = torch.einsum("bts,bcs->bct", weight,
+                         v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)
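
# Illustrative sketch (not part of the original file): the shape bookkeeping of
# the head-first QKV attention above on toy sizes (scaling omitted).
bs, n_heads, ch, length = 2, 4, 8, 16
qkv = torch.randn(bs, 3 * n_heads * ch, length)
q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
w = torch.softmax(torch.einsum("bct,bcs->bts", q, k), dim=-1)    # (bs*heads, len, len)
a = torch.einsum("bts,bcs->bct", w, v).reshape(bs, -1, length)   # (bs, heads*ch, len)
assert a.shape == (bs, n_heads * ch, length)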
@@ -678,7 +686,8 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
-    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
+    output = F.grid_sample(
+        x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
    return output
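
# Illustrative sketch (not part of the original file): the scaling above maps
# absolute pixel coordinates 0..W-1 onto the [-1, 1] range expected by
# F.grid_sample, so a zero flow reproduces the input (an identity warp).
W = 5
coords = torch.arange(W, dtype=torch.float)
normed = 2.0 * coords / max(W - 1, 1) - 1.0
assert normed[0].item() == -1.0 and normed[-1].item() == 1.0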
@@ -690,7 +699,8 @@ class PixelUnshuffle(nn.Module):
    def forward(self, x):
        (b, f, w, h) = x.shape
        x = x.contiguous().view(b, f, w // self.r, self.r, h // self.r, self.r)
-        x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, f * (self.r ** 2), w // self.r, h // self.r)
+        x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(
+            b, f * (self.r ** 2), w // self.r, h // self.r)
        return x
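
# Illustrative sketch (not part of the original file): the space-to-depth
# reshuffle above turns (b, f, w, h) into (b, f*r*r, w//r, h//r); recent PyTorch
# also ships a comparable operation as torch.nn.PixelUnshuffle.
b, f, w, h, r = 1, 3, 4, 4, 2
t = torch.randn(b, f, w, h)
t = t.view(b, f, w // r, r, h // r, r).permute(0, 1, 3, 5, 2, 4).contiguous()
assert t.reshape(b, f * r * r, w // r, h // r).shape == (1, 12, 2, 2)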
@@ -704,6 +714,8 @@ def silu(input):

# create a class wrapper from PyTorch nn.Module, so
# the function now can be easily used in models


class SiLU(nn.Module):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
@@ -720,11 +732,12 @@ class SiLU(nn.Module):
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Init method.
        '''
        super().__init__()  # init the base class
    def forward(self, input):
        '''
@@ -735,12 +748,15 @@ class SiLU(nn.Module):


'''Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvBnRelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
        super(ConvBnRelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
+                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
@@ -753,7 +769,8 @@ class ConvBnRelu(nn.Module):
        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
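
# Illustrative usage sketch (not part of the original file): padding_map picks
# 'same'-style padding for the supported odd kernel sizes, so spatial size is
# preserved at stride 1.
conv_block = ConvBnRelu(filters_in=3, filters_out=16, kernel_size=5)
assert conv_block(torch.randn(1, 3, 32, 32)).shape == (1, 16, 32, 32)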
@@ -770,12 +787,15 @@ class ConvBnRelu(nn.Module):


'''Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvBnSilu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super(ConvBnSilu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
+                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
@@ -788,7 +808,8 @@ class ConvBnSilu(nn.Module):
        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
@@ -808,12 +829,15 @@ class ConvBnSilu(nn.Module):


'''Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvBnLelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super(ConvBnLelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
+                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
@@ -847,12 +871,15 @@ class ConvBnLelu(nn.Module):


'''Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvGnLelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
        super(ConvGnLelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
+                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
@@ -886,12 +913,15 @@ class ConvGnLelu(nn.Module):


'''Convenience class with Conv->GroupNorm->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvGnSilu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
        super(ConvGnSilu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = convnd(filters_in, filters_out, kernel_size,
+                           stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
@@ -904,7 +934,8 @@ class ConvGnSilu(nn.Module):
        # Init params.
        for m in self.modules():
            if isinstance(m, convnd):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
@@ -924,12 +955,15 @@ class ConvGnSilu(nn.Module):


'''Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes.'''


class ConvBnRelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super(ConvBnRelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
+                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
@@ -942,7 +976,8 @@ class ConvBnRelu(nn.Module):
        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
@@ -969,7 +1004,8 @@ class MultiConvBlock(nn.Module):
        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
-        self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
+        self.scale = nn.Parameter(torch.full(
+            (1,), fill_value=scale_init, dtype=torch.float))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
@@ -988,10 +1024,14 @@ class ExpansionBlock(nn.Module):
        super(ExpansionBlock, self).__init__()
        if filters_out is None:
            filters_out = filters_in // 2
-        self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
-        self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
-        self.conjoin = block(filters_out * 2, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
-        self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
+        self.decimate = block(
+            filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
+        self.process_passthrough = block(
+            filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
+        self.conjoin = block(filters_out * 2, filters_out,
+                             kernel_size=3, bias=False, activation=True, norm=False)
+        self.process = block(filters_out, filters_out,
+                             kernel_size=3, bias=False, activation=True, norm=True)

    # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
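    # A hypothetical usage sketch (not part of the original file), following the
    # shape comments above:
    #   exp = ExpansionBlock(filters_in=64)
    #   out = exp(torch.randn(1, 64, 16, 16),    # feature signal
    #             torch.randn(1, 32, 32, 32))    # structure signal at 2x resolution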
@@ -1012,10 +1052,14 @@ class ExpansionBlock2(nn.Module):
        super(ExpansionBlock2, self).__init__()
        if filters_out is None:
            filters_out = filters_in // 2
-        self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
-        self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
-        self.conjoin = block(filters_out * 2, filters_out * 2, kernel_size=3, bias=False, activation=True, norm=False)
-        self.reduce = block(filters_out * 2, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
+        self.decimate = block(
+            filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
+        self.process_passthrough = block(
+            filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
+        self.conjoin = block(filters_out * 2, filters_out * 2,
+                             kernel_size=3, bias=False, activation=True, norm=False)
+        self.reduce = block(filters_out * 2, filters_out,
+                            kernel_size=3, bias=False, activation=True, norm=True)

    # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
@@ -1036,8 +1080,10 @@ class ConjoinBlock(nn.Module):
            filters_out = filters_in
        if filters_pt is None:
            filters_pt = filters_in
-        self.process = block(filters_in + filters_pt, filters_in + filters_pt, kernel_size=3, bias=False, activation=True, norm=norm)
-        self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)
+        self.process = block(filters_in + filters_pt, filters_in + filters_pt,
+                             kernel_size=3, bias=False, activation=True, norm=norm)
+        self.decimate = block(filters_in + filters_pt, filters_out,
+                              kernel_size=1, bias=False, activation=False, norm=norm)

    def forward(self, input, passthrough):
        x = torch.cat([input, passthrough], dim=1)
@@ -1053,7 +1099,8 @@ class ReferenceJoinBlock(nn.Module):
                                       scale_init=residual_weight_init_factor, norm=False,
                                       weight_init_factor=residual_weight_init_factor)
        if join:
-            self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
+            self.join_conv = block(
+                nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
        else:
            self.join_conv = None
@@ -1070,7 +1117,8 @@ class UpconvBlock(nn.Module):
class UpconvBlock(nn.Module):
    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
        super(UpconvBlock, self).__init__()
-        self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
+        self.process = block(filters_in, filters_out, kernel_size=3,
+                             bias=bias, activation=activation, norm=norm)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
@@ -1083,21 +1131,29 @@ class FinalUpsampleBlock2x(nn.Module):
        super(FinalUpsampleBlock2x, self).__init__()
        if scale == 2:
            self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
-                                       UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
-                                       block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
+                                       UpconvBlock(
+                                           nf, nf // 2, block=block, norm=False, activation=True, bias=True),
+                                       block(nf // 2, nf // 2, kernel_size=3,
+                                             norm=False, activation=False, bias=True),
                                       block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
        else:
            self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
-                                       UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
-                                       block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
-                                       UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
-                                       block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
+                                       UpconvBlock(
+                                           nf, nf, block=block, norm=False, activation=True, bias=True),
+                                       block(nf, nf, kernel_size=3, norm=False,
+                                             activation=False, bias=True),
+                                       UpconvBlock(
+                                           nf, nf // 2, block=block, norm=False, activation=True, bias=True),
+                                       block(nf // 2, nf // 2, kernel_size=3,
+                                             norm=False, activation=False, bias=True),
                                       block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))

    def forward(self, x):
        return self.chain(x)


# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
    b, c, h, w = input.shape
    nodim = input.view(b, c, h * w)