import torch
import torch.nn as nn

from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
    Downsample, Upsample
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from trainer.networks import register_model
from utils.util import get_mask_from_lengths


class DiscreteSpectrogramConditioningBlock(nn.Module):
    def __init__(self, dvae_channels, channels):
        super().__init__()
        self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
                                  normalization(channels),
                                  nn.SiLU(),
                                  nn.Conv1d(channels, channels, kernel_size=3))
"""
2021-10-15 17:51:17 +00:00
Embeds the given codes and concatenates them onto x . Return shape is the same as x . shape .
2021-10-14 03:23:18 +00:00
: param x : bxcxS waveform latent
: param codes : bxN discrete codes , N < = S
"""
2021-10-16 15:02:01 +00:00
def forward ( self , x , dvae_in ) :
b , c , S = x . shape
_ , q , N = dvae_in . shape
2021-10-17 23:32:46 +00:00
emb = self . intg ( dvae_in )
2021-10-14 03:23:18 +00:00
emb = nn . functional . interpolate ( emb , size = ( S , ) , mode = ' nearest ' )
2021-10-17 23:32:46 +00:00
return torch . cat ( [ x , emb ] , dim = 1 )
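
# Shape sketch for the block above (illustrative values, not executed on import):
#   block = DiscreteSpectrogramConditioningBlock(dvae_channels=512, channels=64)
#   x = torch.randn(2, 64, 4096)         # bxcxS waveform latent
#   codes = torch.randn(2, 512, 160)     # bxqxN dVAE embedding, N <= S
#   block(x, codes).shape == (2, 128, 4096)  # channel count doubles after the concat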


class DiffusionVocoderWithRef(nn.Module):
"""
The full UNet model with attention and timestep embedding .
Customized to be conditioned on a spectrogram prior .
: param in_channels : channels in the input Tensor .
: param spectrogram_channels : channels in the conditioning spectrogram .
: param model_channels : base channel count for the model .
: param out_channels : channels in the output Tensor .
: param num_res_blocks : number of residual blocks per downsample .
: param attention_resolutions : a collection of downsample rates at which
attention will take place . May be a set , list , or tuple .
For example , if this contains 4 , then at 4 x downsampling , attention
will be used .
: param dropout : the dropout probability .
: param channel_mult : channel multiplier for each level of the UNet .
: param conv_resample : if True , use learned convolutions for upsampling and
downsampling .
: param dims : determines if the signal is 1 D , 2 D , or 3 D .
: param num_heads : the number of attention heads in each attention layer .
: param num_heads_channels : if specified , ignore num_heads and instead use
a fixed channel width per attention head .
: param num_heads_upsample : works with num_heads to set a different number
of heads for upsampling . Deprecated .
: param use_scale_shift_norm : use a FiLM - like conditioning mechanism .
: param resblock_updown : use residual blocks for up / downsampling .
: param use_new_attention_order : use a different attention pattern for potentially
increased efficiency .
"""
    def __init__(
            self,
            model_channels,
            in_channels=1,
            out_channels=2,  # mean and variance
            discrete_codes=512,
            dropout=0,
            # res:           1, 2,   4, 8, 16, 32, 64, 128, 256, 512, 1K, 2K
            channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
            num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
            # spec_cond:     1, 0,   0, 1, 0,  0,  1,  0,   0,   1,   0,  0
            # attn:          0, 0,   0, 0, 0,  0,  0,  0,   0,   1,   1,  1
            spectrogram_conditioning_resolutions=(512,),
            attention_resolutions=(512, 1024, 2048),
            conv_resample=True,
            dims=1,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            kernel_size=3,
            scale_factor=2,
            conditioning_inputs_provided=True,
            conditioning_input_dim=80,
            time_embed_dim_multiplier=4,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.dims = dims
        padding = 1 if kernel_size == 3 else 2

        time_embed_dim = model_channels * time_embed_dim_multiplier
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.conditioning_enabled = conditioning_inputs_provided
        if conditioning_inputs_provided:
            self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
            if ds in spectrogram_conditioning_resolutions:
                self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
                ch *= 2  # The conditioning block concatenates its embedding onto the latent, doubling the channel count.

            for _ in range(num_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            kernel_size=kernel_size,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
            for i in range(num_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                            kernel_size=kernel_size,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)
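
    # fp16 usage note (a sketch): construct with use_fp16=True so self.dtype becomes
    # torch.float16 (forward() casts activations to it), then call convert_to_fp16()
    # so the torso weights match that dtype.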

    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param spectrogram: an [N x q x S'] dVAE embedding used for conditioning.
        :param conditioning_inputs: optional reference clips fed to the contextual embedder.
        :param num_conditioning_signals: the number of valid reference clips per batch element.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, thus this requirement.
        if self.conditioning_enabled:
            assert conditioning_inputs is not None
            assert num_conditioning_signals is not None

        hs = []
        emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.conditioning_enabled:
            #emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
            emb2 = self.contextual_embedder(conditioning_inputs[:, 0])
            emb = emb1 + emb2
        else:
            emb = emb1

        h = x.type(self.dtype)
        for k, module in enumerate(self.input_blocks):
            if isinstance(module, DiscreteSpectrogramConditioningBlock):
                h = module(h, spectrogram)
            else:
                h = module(h, emb)
                hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)


@register_model
def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
    return DiffusionVocoderWithRef(**opt_net['kwargs'])
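
# Example of what the trainer is expected to pass (a sketch; the opt_net structure beyond
# the 'kwargs' key used above is an assumption about the surrounding trainer):
#   opt_net = {'kwargs': {'model_channels': 32, 'conditioning_inputs_provided': True}}
# yields DiffusionVocoderWithRef(model_channels=32, conditioning_inputs_provided=True).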


# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
    clip = torch.randn(2, 1, 40960)
    #spec = torch.randint(8192, (2, 40,))
    spec = torch.randn(2, 512, 160)
    cond = torch.randn(2, 3, 80, 173)
    ts = torch.LongTensor([555, 556])
    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
    print(model(clip, ts, spec, cond, 3).shape)
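
    # Also sanity-check the unconditional path (a sketch: conditioning disabled, so the
    # reference clips and their count are omitted).
    model_uncond = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
    print(model_uncond(clip, ts, spec).shape)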