2022-03-22 17:52:46 +00:00
import math
2022-04-01 20:15:17 +00:00
import random
2022-03-22 17:52:46 +00:00
from abc import abstractmethod
import torch
import torch . nn as nn
import torch . nn . functional as F
from torch import autocast
2022-05-01 22:24:24 +00:00
from tortoise . models . arch_util import normalization , AttentionBlock
2022-03-22 17:52:46 +00:00
def is_latent(t):
    """
    Return True when ``t`` carries continuous (latent) conditioning data.

    Any floating-point dtype qualifies, so float16/bfloat16 latents (for example
    ones produced under autocast) take the latent branch instead of being
    treated as token ids and handed to an ``nn.Embedding``.
    """
    return t.is_floating_point()
def is_sequence(t):
    """
    Report whether ``t`` is a sequence of discrete tokens.

    Only 64-bit integer tensors (``torch.long``) qualify.
    """
    sequence_dtype = torch.long
    return sequence_dtype == t.dtype
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    :param timesteps: a 1-D tensor holding one timestep per batch element.
    :param dim: width of each embedding vector.
    :param max_period: sets the lowest frequency used by the sinusoids.
    :return: an [N x dim] float tensor laid out as [cos(...), sin(...)], with one
             trailing zero column when ``dim`` is odd.
    """
    n_freqs = dim // 2
    steps = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
    # Geometric frequency ladder from 1 down towards 1 / max_period.
    freqs = torch.exp(-math.log(max_period) * steps / n_freqs).to(device=timesteps.device)
    phases = timesteps[:, None].float() * freqs[None]
    parts = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Odd widths get one zero column so the result is exactly `dim` wide.
        parts.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(parts, dim=-1)
class TimestepBlock(nn.Module):
    """Base class for modules whose forward pass also consumes a timestep embedding."""

    @abstractmethod
    def forward(self, x, emb):
        """
        Run this module on ``x`` conditioned on the timestep embedding ``emb``.

        Subclasses are expected to override this; the base version has no body
        beyond this docstring and therefore returns None.
        """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that hands the timestep embedding only to children
    that are TimestepBlocks; every other child is called with the tensor alone.
    """

    def forward(self, x, emb):
        out = x
        for module in self:
            needs_emb = isinstance(module, TimestepBlock)
            out = module(out, emb) if needs_emb else module(out)
        return out
class ResBlock(TimestepBlock):
    """
    Residual block with 1-D convolutions, conditioned on a timestep embedding.

    :param channels: number of input channels.
    :param emb_channels: width of the timestep embedding ``emb``.
    :param dropout: dropout probability applied before the output convolution.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param dims: accepted for interface compatibility only; unused (the block always uses Conv1d).
    :param kernel_size: odd kernel size of the output convolution; padded to preserve length.
    :param efficient_config: when True the input and skip convolutions are 1-wide, otherwise 3-wide.
    :param use_scale_shift_norm: when True the embedding yields a per-channel (scale, shift) pair
        applied right after the output normalization; otherwise it is added to the hidden state.
    :raises ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        dims=2,
        kernel_size=3,
        efficient_config=True,
        use_scale_shift_norm=False,
    ):
        super().__init__()
        # Padding of kernel_size // 2 only preserves the sequence length for odd kernels,
        # which the residual addition in forward() relies on.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size!r}")
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = kernel_size // 2
        eff_kernel = 1 if efficient_config else 3
        eff_padding = 0 if efficient_config else 1

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                # Twice the width when the embedding must provide both a scale and a shift.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
        )
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)

    def forward(self, x, emb):
        """
        :param x: input features (fed to Conv1d layers, so presumably [N x C x S] — confirm with callers).
        :param emb: timestep embedding of width ``emb_channels``.
        :return: ``skip_connection(x)`` plus the conditioned residual branch.
        """
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes so the embedding broadcasts over the remaining axes of h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # Normalize first, then modulate with the embedding before the rest of out_layers.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
2022-04-01 20:15:17 +00:00
class DiffusionLayer(TimestepBlock):
    """One backbone stage: a timestep-conditioned ResBlock followed by self-attention."""

    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        self.resblk = ResBlock(
            model_channels,
            model_channels,
            dropout,
            model_channels,
            dims=1,
            use_scale_shift_norm=True,
        )
        self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)

    def forward(self, x, time_emb):
        conditioned = self.resblk(x, time_emb)
        return self.attn(conditioned)
class DiffusionTts(nn.Module):
    """
    Diffusion model conditioned on aligned codes/latents and a speech conditioning latent.

    forward() maps an ``in_channels``-wide noisy input plus diffusion timesteps to
    ``out_channels`` outputs (by default 200, annotated below as "mean and variance").
    """

    def __init__(
        self,
        model_channels=512,
        num_layers=8,
        in_channels=100,
        in_latent_channels=512,
        in_tokens=8193,
        out_channels=200,  # mean and variance
        dropout=0,
        use_fp16=True,
        num_heads=16,
        # Parameters for regularization.
        layer_drop=.1,
        unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.num_heads = num_heads
        self.unconditioned_percentage = unconditioned_percentage
        self.enable_fp16 = use_fp16
        self.layer_drop = layer_drop

        self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
            nn.Linear(model_channels, model_channels),
        )
        # Either code_converter or latent_conditioner is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
        # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
        # transformer network.
        self.code_embedding = nn.Embedding(in_tokens, model_channels)
        self.code_converter = nn.Sequential(
            *[AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) for _ in range(3)]
        )
        self.code_norm = normalization(model_channels)
        self.latent_conditioner = nn.Sequential(
            nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
            *[AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) for _ in range(4)],
        )
        # Encodes reference speech (see get_conditioning()); its convolutions emit model_channels * 2 channels.
        self.contextual_embedder = nn.Sequential(
            nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
            nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
            *[AttentionBlock(model_channels * 2, num_heads, relative_pos_embeddings=True, do_checkpoint=False)
              for _ in range(5)],
        )
        # Learned stand-in for the code embedding when conditioning is dropped (training) or disabled.
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
            *[DiffusionLayer(model_channels, dropout, num_heads) for _ in range(3)]
        )

        self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
        self.layers = nn.ModuleList(
            [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
            + [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]
        )

        self.out = nn.Sequential(
            normalization(model_channels),
            nn.SiLU(),
            nn.Conv1d(model_channels, out_channels, 3, padding=1),
        )

    def get_grad_norm_parameter_groups(self):
        """Return named parameter groups (lists of parameters) used for per-group gradient-norm reporting."""
        groups = {
            'minicoder': list(self.contextual_embedder.parameters()),
            'layers': list(self.layers.parameters()),
            # Each converter's parameters are listed exactly once so they are not double-counted in the group norm.
            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters())
                               + list(self.latent_conditioner.parameters()),
            'timestep_integrator': list(self.conditioning_timestep_integrator.parameters())
                                   + list(self.integrating_conv.parameters()),
            'time_embed': list(self.time_embed.parameters()),
        }
        return groups

    def get_conditioning(self, conditioning_input):
        """
        Encode one or more reference clips into a single conditioning latent.

        :param conditioning_input: a 3-D tensor (one clip per batch element) or a 4-D tensor whose
            dim 1 indexes several clips per batch element; each clip goes through contextual_embedder.
        :return: the per-clip encodings concatenated along the last axis and averaged over it; pass
            this to forward()/timestep_independent() as ``conditioning_latent``.
        """
        speech_conditioning_input = (
            conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
        )
        conds = [
            self.contextual_embedder(speech_conditioning_input[:, j])
            for j in range(speech_conditioning_input.shape[1])
        ]
        conds = torch.cat(conds, dim=-1)
        return conds.mean(dim=-1)

    def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
        """
        Compute the part of the conditioning that does not depend on the diffusion timestep, so it can
        be cached and passed to forward() as ``precomputed_aligned_embeddings``.

        :param aligned_conditioning: a float latent (permuted here from BxSxC to BxCxS and fed to
            latent_conditioner) or a tensor of token ids (fed to code_embedding and code_converter).
        :param conditioning_latent: output of get_conditioning(); split in two along dim 1 into a
            scale and a shift applied after code_norm.
        :param expected_seq_len: length the code embedding is nearest-neighbour interpolated to.
        :param return_code_pred: when True, also return a prediction from mel_head.
        :return: the expanded code embedding, or (expanded code embedding, mel prediction).
        """
        latent_input = is_latent(aligned_conditioning)
        if latent_input:
            # Shuffle aligned_latent to BxCxS format
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

        cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
        if latent_input:
            code_emb = self.latent_conditioner(aligned_conditioning)
        else:
            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
            code_emb = self.code_converter(code_emb)
        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                               device=code_emb.device) < self.unconditioned_percentage
            code_emb = torch.where(unconditioned_batches,
                                   self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                   code_emb)
        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

        if not return_code_pred:
            return expanded_code_emb
        mel_pred = self.mel_head(expanded_code_emb)
        # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because
        # we don't want that gradient being used to train parameters through the codes_embedder as it unbalances
        # contributions to that network from the MSE loss.
        mel_pred = mel_pred * unconditioned_batches.logical_not()
        return expanded_code_emb, mel_pred

    def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None,
                precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
        :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
        :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent().
        :param conditioning_free: When set, aligned_conditioning, conditioning_latent and
            precomputed_aligned_embeddings are ignored and the learned unconditioned embedding is used instead.
        :param return_code_pred: When set, also return the mel prediction produced by timestep_independent().
            Cannot be combined with precomputed_aligned_embeddings or conditioning_free.
        :return: an [N x C x ...] Tensor of outputs, or (outputs, mel_pred) when return_code_pred is set.
        """
        assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_latent is not None)
        assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.
        # The conditioning-free path never computes mel_pred, so this combination would otherwise fail
        # with a NameError at the final return.
        assert not (return_code_pred and conditioning_free), 'return_code_pred cannot be combined with conditioning_free.'

        unused_params = []
        if conditioning_free:
            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
            unused_params.extend(list(self.latent_conditioner.parameters()))
        else:
            if precomputed_aligned_embeddings is not None:
                code_emb = precomputed_aligned_embeddings
            else:
                code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1], True)
                if is_latent(aligned_conditioning):
                    unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
                else:
                    unused_params.extend(list(self.latent_conditioner.parameters()))

            unused_params.append(self.unconditioned_embedding)

        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
        x = self.inp_block(x)
        x = torch.cat([x, code_emb], dim=1)
        x = self.integrating_conv(x)
        for i, lyr in enumerate(self.layers):
            # Do layer drop where applicable. Do not drop first and last layers.
            if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers) - 1) and random.random() < self.layer_drop:
                unused_params.extend(list(lyr.parameters()))
            else:
                # Only the first block (i == 0) runs with autocast disabled, for improved precision; every later
                # block, including the last, runs under autocast when enable_fp16 is set.
                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
                    x = lyr(x, time_emb)

        x = x.float()
        out = self.out(x)

        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
        extraneous_addition = 0
        for p in unused_params:
            extraneous_addition = extraneous_addition + p.mean()
        out = out + extraneous_addition * 0

        if return_code_pred:
            return out, mel_pred
        return out
2022-03-22 17:52:46 +00:00
if __name__ == '__main__':
    # Smoke test: runs both conditioning paths once on random data.
    clip = torch.randn(2, 100, 400)
    aligned_latent = torch.randn(2, 388, 512)
    aligned_sequence = torch.randint(0, 8192, (2, 100))
    cond_clip = torch.randn(2, 100, 400)
    ts = torch.LongTensor([600, 600])
    model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
    # forward() expects a pre-computed conditioning latent (see get_conditioning()), not a raw reference clip.
    cond = model.get_conditioning(cond_clip)
    # Test with latent aligned conditioning
    o = model(clip, ts, aligned_latent, cond)
    print('latent conditioning output:', tuple(o.shape))
    # Test with sequence aligned conditioning
    o = model(clip, ts, aligned_sequence, cond)
    print('sequence conditioning output:', tuple(o.shape))
2022-04-01 20:15:17 +00:00