2021-06-06 19:57:22 +00:00
from abc import abstractmethod
import math
import numpy as np
import torch
import torch as th
import torch . nn as nn
import torch . nn . functional as F
import torchvision # For debugging, not actually used.
from models . diffusion . fp16_util import convert_module_to_f16 , convert_module_to_f32
from models . diffusion . nn import (
conv_nd ,
linear ,
avg_pool_nd ,
zero_module ,
normalization ,
timestep_embedding ,
)
from trainer . networks import register_model
from utils . util import checkpoint
class AttentionPool2d(nn.Module):
    """
    Attention-based pooling of a feature map, adapted from CLIP:
    https://github.com/openai/CLIP/blob/main/clip/model.py

    The spatial positions are flattened, their mean is prepended as an extra
    token, a learned positional embedding is added, and multi-head attention
    is run over the sequence. The output at the prepended (mean) position is
    returned.

    :param spacial_dim: side length the positional embedding is sized for
        (spacial_dim ** 2 positions plus one for the mean token). Inputs with
        fewer positions use a prefix of the embedding.
    :param embed_dim: number of input channels.
    :param num_heads_channels: channels per attention head.
    :param output_dim: number of output channels (defaults to embed_dim).
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        num_positions = spacial_dim ** 2 + 1  # +1 for the prepended mean token
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, num_positions) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        tokens = x.reshape(batch, channels, -1)  # N x C x (HW)
        summary = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([summary, tokens], dim=-1)  # N x C x (HW+1)
        seq_len = tokens.shape[-1]
        pos = self.positional_embedding[None, :, :seq_len].to(tokens.dtype)
        tokens = tokens + pos
        pooled = self.c_proj(self.attention(self.qkv_proj(tokens)))
        # Keep only the output at the summary-token position.
        return pooled[:, :, 0]
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Containers test ``isinstance(layer, TimestepBlock)`` to decide whether to
    pass the embedding along. Because ``nn.Module`` is not an ``abc.ABC``, the
    ``@abstractmethod`` marker below is not enforced at instantiation time, so
    the base implementation raises explicitly instead of silently returning
    None.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward(x, emb)"
        )
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    An nn.Sequential container that calls each child in order, handing the
    timestep embedding to children that are TimestepBlocks and calling every
    other child with the running activation alone.
    """

    def forward(self, x, emb):
        h = x
        for child in self:
            h = child(h, emb) if isinstance(child, TimestepBlock) else child(h)
        return h
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions (x2 each).
                 1D signals are upsampled x4 (and use a 5-wide convolution);
                 any other value is upsampled x2 along every spatial axis.
    :param out_channels: output channels of the optional convolution
                         (defaults to `channels`).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            # 1D uses a 5-wide kernel (paired with its x4 scale in forward).
            ksize, pad = (5, 2) if dims == 1 else (3, 1)
            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Leave the outermost spatial axis alone, double the inner two.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            factor = 4 if self.dims == 1 else 2
            x = F.interpolate(x, scale_factor=factor, mode="nearest")
        return self.conv(x) if self.use_conv else x
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied; if False
                     an average pool is used and out_channels must equal
                     channels.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions (stride
                 (1, 2, 2)). 1D uses stride 4 with a 5-wide kernel; 2D uses
                 stride 2 with a 3-wide kernel.
    :param out_channels: output channels of the optional convolution
                         (defaults to `channels`).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if dims == 1:
            stride, ksize, pad = 4, 5, 2
        elif dims == 2:
            stride, ksize, pad = 2, 3, 1
        else:
            stride, ksize, pad = (1, 2, 2), 3, 1
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
            return
        self.op = conv_nd(
            dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
        )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is split into a
        per-channel scale and shift applied right after the first output
        normalization (FiLM-like); otherwise it is added to the activations.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param kernel_size: (int) spatial kernel size of the convolutions. Must be
        odd so that "same" padding (kernel_size // 2) preserves the spatial
        size and the residual sum lines up.
    :raises ValueError: if kernel_size is even.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        if kernel_size % 2 != 1:
            # Symmetric padding cannot keep the size for an even kernel, which
            # would make skip_connection(x) + h fail in _forward().
            raise ValueError(f"ResBlock kernel_size must be odd, got {kernel_size}")
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        # "Same" padding for any odd kernel size (1 for k=3, 2 for k=5, ...).
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                # Twice the channels when the output is split into scale/shift.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(self._forward, x, emb)

    def _forward(self, x, emb):
        if self.updown:
            # Resample between the norm/activation and the conv, and resample
            # the skip path identically so the shapes still match.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes so the embedding broadcasts over h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of heads; used only when num_head_channels == -1.
    :param num_head_channels: if not -1, fix the width of each head instead
        and derive the head count as channels // num_head_channels.
    :param use_new_attention_order: if True use QKVAttention (split q/k/v
        before heads); otherwise QKVAttentionLegacy (split heads first).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            num_heads = channels // num_head_channels
        self.num_heads = num_heads
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attention_cls(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, x)

    def _forward(self, x):
        batch, channels, *spatial = x.shape
        flat = x.reshape(batch, channels, -1)
        attended = self.proj_out(self.attention(self.qkv(self.norm(flat))))
        # Residual connection, then restore the original spatial layout.
        return (flat + attended).reshape(batch, channels, *spatial)
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.

    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timesteps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )

    :param model: the attention module being profiled; its `total_ops`
        counter is incremented in place.
    :param _x: the module inputs (unused).
    :param y: the module output, an [N x C x ...] Tensor. A tuple/list whose
        first element is that Tensor is also accepted.
    """
    # A forward hook receives the output Tensor itself; indexing it with [0]
    # would select the first sample rather than the batch.
    out = y[0] if isinstance(y, (tuple, list)) else y
    b, c, *spatial = out.shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention +
    input/output heads shaping: the channel axis is first split into heads,
    and each head's slice is then split into q, k and v.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        per_head = qkv.reshape(bs * heads, ch * 3, length)
        q, k, v = per_head.split(ch, dim=1)
        # Scale both q and k by ch ** -0.25 (ch ** -0.5 overall) before the
        # product; more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", weight, v)
        return mixed.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order:
    the channel axis is first split into q, k and v, and each of those is
    then split into heads.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale both q and k by ch ** -0.25; more stable with f16 than
        # dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        heads_shape = (bs * self.n_heads, ch, length)
        q_heads = (q * scale).view(*heads_shape)
        k_heads = (k * scale).view(*heads_shape)
        logits = th.einsum("bct,bcs->bts", q_heads, k_heads)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", weight, v.reshape(*heads_shape))
        return mixed.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param image_size: input resolution; stored on the model but not otherwise
        used by this class.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_fp16: if True, forward() casts activations to float16 before
        the torso; the torso weights are converted by convert_to_fp16().
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for
        potentially increased efficiency.
    :param use_raw_y_as_embedding: if True, the `y` passed to forward() is
        added directly to the timestep embedding, so it must broadcast to
        [N x (4 * model_channels)]. Mutually exclusive with num_classes.
    :raises ValueError: if both num_classes and use_raw_y_as_embedding are set.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_raw_y_as_embedding=False,
    ):
        super().__init__()
        # Both modes consume `y` in forward(); at most one may be enabled.
        # (Neither being enabled is the ordinary unconditional model.)
        if num_classes is not None and use_raw_y_as_embedding:
            raise ValueError(
                "num_classes and use_raw_y_as_embedding are mutually exclusive"
            )
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        self.use_raw_y_as_embedding = use_raw_y_as_embedding

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel counts of every input block, consumed (LIFO) by the decoder
        # skip connections below.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Here ch == model_channels * channel_mult[0]; the final conv must take
        # that many inputs (it equals model_channels only if channel_mult[0] == 1).
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional; or a Tensor
            added to the timestep embedding if use_raw_y_as_embedding.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y is not None, "class-conditional model requires y"
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        if self.use_raw_y_as_embedding:
            assert y is not None, "use_raw_y_as_embedding requires y"
            emb = emb + y

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            # Skip connection from the matching encoder block.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    `low_res` is bilinearly resized to the spatial size of `x` and
    concatenated to it on the channel axis, followed by `num_corruptions`
    spatially-constant channels built from `corruption_factor` (all zeros
    when no factor is given).

    Note: `num_corruptions` is the third positional parameter, so any further
    UNetModel arguments should be passed by keyword.
    """

    def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
        # The torso sees x, the resized low_res (assumed to have in_channels
        # channels as well), and one channel per corruption factor.
        super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
        # Assigned after nn.Module.__init__ has run.
        self.num_corruptions = num_corruptions

    def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
        """
        :param x: an [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param low_res: the low-resolution conditioning image (4-D, resized to
            H x W here).
        :param corruption_factor: optional per-sample corruption factors,
            presumably [N x K]; only the first `num_corruptions` columns are
            used.
        :param kwargs: forwarded to UNetModel.forward (e.g. `y`).
        :return: the UNetModel output for the concatenated input.
        """
        b, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        if corruption_factor is not None and corruption_factor.shape[1] != self.num_corruptions:
            if not hasattr(self, '_corruption_factor_warning'):
                print(f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.")
                self._corruption_factor_warning = True
            corruption_factor = corruption_factor[:, :self.num_corruptions]
        # With no corruption channels there is nothing to append (and reshaping
        # an empty factor with -1 would raise).
        if self.num_corruptions > 0:
            if corruption_factor is None:
                corruption = torch.zeros(
                    (b, self.num_corruptions, new_height, new_width),
                    dtype=upsampled.dtype,
                    device=x.device,
                )
            else:
                # Match upsampled's dtype/device so concatenation does not
                # promote the whole network input (e.g. to float64).
                corruption = (
                    corruption_factor.to(device=upsampled.device, dtype=upsampled.dtype)
                    .reshape(b, -1, 1, 1)
                    .repeat(1, 1, new_height, new_width)
                )
            upsampled = torch.cat([upsampled, corruption], dim=1)
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    For usage, see UNetModel; shared constructor arguments have the same
    meaning. `pool` selects the output head:

    - "adaptive": norm, SiLU, nn.AdaptiveAvgPool2d((1, 1)) and a 1x1 conv,
      flattened to [N x out_channels] (the pooling layer is 2-D).
    - "attention": norm, SiLU and AttentionPool2d over the final feature map;
      requires num_head_channels != -1.
    - "spatial": the spatial mean of every input block's output and of the
      middle block's output are concatenated and fed to a Linear-ReLU-Linear
      MLP.
    - "spatial_v2": like "spatial", with normalization + SiLU in the MLP.

    :raises NotImplementedError: for any other `pool` value.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # Running sum of the channel counts of every block whose output the
        # "spatial" heads concatenate.
        self._feature_size = model_channels
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    @staticmethod
    def _spatial_mean(h, dtype):
        """Cast `h` to `dtype` and average it over every axis after the channel axis."""
        return h.type(dtype).mean(dim=tuple(range(2, h.dim())))

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        use_spatial = self.pool.startswith("spatial")
        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if use_spatial:
                results.append(self._spatial_mean(h, x.dtype))
        h = self.middle_block(h, emb)
        if use_spatial:
            results.append(self._spatial_mean(h, x.dtype))
            h = th.cat(results, dim=-1)
            return self.out(h)
        h = h.type(x.dtype)
        return self.out(h)
@register_model
def register_unet_diffusion(opt_net, opt):
    """
    Build a SuperResModel from network options.

    :param opt_net: network options; its 'args' entry is expanded as keyword
        arguments to SuperResModel.
    :param opt: the global options (unused here).
    :return: the constructed SuperResModel.
    """
    model_kwargs = opt_net['args']
    return SuperResModel(**model_kwargs)
if __name__ == '__main__':
    # Smoke test: build a small super-resolution UNet and run one forward pass.
    # Attention is placed at 16x16 and 8x8 feature maps of a 128x128 input,
    # i.e. at downsample rates 8 and 16.
    attention_ds = [128 // int(res) for res in "16,8".split(",")]
    srm = SuperResModel(
        image_size=128,
        in_channels=3,
        model_channels=64,
        out_channels=3,
        num_res_blocks=1,
        attention_resolutions=attention_ds,
        num_heads=4,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
    )
    x = torch.randn(1, 3, 128, 128)
    low_res_img = torch.randn(1, 3, 32, 32)
    ts = torch.LongTensor([555])
    y = srm(x, ts, low_res=low_res_img)
    print(y.shape, y.mean(), y.std(), y.min(), y.max())