2021-07-28 02:33:30 +00:00
import torch
import torch . nn as nn
import torch . nn . functional as F
2021-08-09 17:56:06 +00:00
from munch import munchify
2021-07-28 02:33:30 +00:00
2021-08-05 02:07:45 +00:00
from models . gpt_voice . lucidrains_gpt import Transformer
2021-08-05 13:07:17 +00:00
from models . gpt_voice . min_gpt import GPT , GPTConfig
2021-07-28 02:33:30 +00:00
from models . tacotron2 . taco_utils import get_mask_from_lengths
from models . tacotron2 . text import symbols
from trainer . networks import register_model
2021-08-07 04:08:51 +00:00
from utils . util import opt_get
2021-07-28 02:33:30 +00:00
class GptTts(nn.Module):
    """Autoregressive transformer over a text-token sequence followed by a MEL-token sequence.

    Text tokens and MEL tokens are embedded separately (each with its own learned position
    embedding), concatenated along the sequence axis, run through a single transformer, and
    projected by two separate heads back to text and MEL vocabularies.
    """
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
    # Number of real MEL codes; ids at or above this value are the special tokens below.
    NUMBER_MEL_CODES = 512
    MEL_DICTIONARY_SIZE = NUMBER_MEL_CODES + 3
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2

    def __init__(self, layers=8, model_dim=512, heads=8):
        """Build the embeddings, transformer trunk and output heads.

        Args:
            layers: Transformer depth.
            model_dim: Width of the embeddings and of the transformer.
            heads: Number of attention heads.
        """
        super().__init__()
        # 900 is the max number of MEL frames; the MEL token budget is a quarter of that.
        # NOTE(review): the original comment said the VQVAE outputs 1/8 of the input mel as tokens,
        # which disagrees with the "// 4" below - confirm which is intended.
        max_mel_frames = 900 * 1 // 4

        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1 + self.MAX_SYMBOLS_PER_PHRASE + max_mel_frames,
                               heads=heads, attn_dropout=.1, ff_dropout=.1)

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
        """Compute next-token losses for both sequences and return the predicted MEL codes.

        Args:
            text_inputs: 2-D integer tensor (batch, text sequence) of text token ids.
            text_lengths: Not used by this method.
            mel_targets: 2-D integer tensor (batch, mel sequence) of MEL token ids.
            output_lengths: Per-sample MEL lengths, used to zero the padded tail of the predictions.

        Returns:
            Tuple of (mean text loss, mean mel loss, predicted mel codes, formatted mel targets).
            In the last two, every id that is not a real MEL code is replaced by 0.
        """
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = self.gpt(emb)

        # Compute logits for text and mel heads
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        text_logits = self.text_head(text_logits)
        mel_logits = self.mel_head(mel_logits)

        # Compute loss. Targets are the inputs shifted left by one. The last element of the logits is
        # unneeded because the input to the transformer contains a <EOS> token for both text and mel.
        text_targets = text_inputs[:, 1:]
        text_logits = text_logits.permute(0, 2, 1)[:, :, :-1]
        loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
        mel_targets = mel_targets[:, 1:]
        mel_logits = mel_logits.permute(0, 2, 1)[:, :, :-1]
        loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')

        # Fix up mel_logits so it can go into a VAE decoder as well.
        # argmax over the logits picks the same class as argmax over their softmax.
        mel_codes = torch.argmax(mel_logits, dim=1)
        # presumably get_mask_from_lengths marks valid positions, so its inverse marks padding - verify.
        mel_pad_mask = ~get_mask_from_lengths(output_lengths - 1, mel_targets.shape[1])
        mel_codes = mel_codes.masked_fill(mel_pad_mask, 0)
        # Strip off <EOS> token too (or padding). The important part is that the output sequence
        # length is identical to the VAE input.
        mel_codes = mel_codes[:, :-1]
        # The VAE doesn't know about START/STOP/PAD
        mel_codes = mel_codes * (mel_codes < self.NUMBER_MEL_CODES)

        # This class also returns the mel_targets for validation purposes. Format those.
        mel_targets = mel_targets[:, :-1]
        mel_targets = mel_targets * (mel_targets < self.NUMBER_MEL_CODES)
        return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets

    def inference(self, text_inputs):
        """Greedily decode MEL codes for a batch of text sequences.

        Decoding stops once every batch element has produced MEL_STOP_TOKEN or the sequence reaches
        max_mel_frames tokens, whichever comes first.

        Args:
            text_inputs: 2-D integer tensor (batch, text sequence) of text token ids.

        Returns:
            2-D tensor of MEL codes with the leading start token and the final token removed and
            all special tokens replaced by 0.
        """
        b, s = text_inputs.shape
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))

        mel_seq = torch.full((b, 1), fill_value=self.MEL_START_TOKEN, dtype=torch.long, device=text_emb.device)
        stop_encountered = torch.zeros((b,), dtype=torch.bool, device=text_emb.device)
        # The limit is on the sequence axis (shape[1]); len() would measure the batch axis.
        while not torch.all(stop_encountered) and mel_seq.shape[1] < self.max_mel_frames:
            mel_emb = self.mel_embedding(mel_seq)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            # Only the final position predicts the next token, so only that one is projected.
            next_logits = self.mel_head(self.final_norm(enc[:, -1]))
            next_codes = torch.argmax(next_logits, dim=-1)
            mel_seq = torch.cat([mel_seq, next_codes.unsqueeze(1)], dim=1)
            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:, -1] == self.MEL_STOP_TOKEN)

        if mel_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
        mel_seq = mel_seq[:, 1:-1]  # Remove first and last tokens, which were artificially added for GPT
        mel_seq = mel_seq * (mel_seq < self.NUMBER_MEL_CODES)  # The DVAE doesn't understand BOS/EOS/PAD tokens.
        return mel_seq

    def inference_beam_topk(self, text):
        """Run inference_beam, expanding each beam with its k most probable next tokens."""
        def topk_sampler(distribution, k):
            return torch.topk(distribution, k=k, dim=-1)
        return self.inference_beam(text, topk_sampler)

    def inference_beam_sampled(self, text):
        """Run inference_beam, expanding each beam with k tokens sampled without replacement."""
        def multinomial_sampler(distribution, k):
            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
            values = torch.gather(distribution, dim=1, index=indices)

            # Mirrors the .indices/.values attributes of the torch.topk result that inference_beam reads.
            class container:
                def __init__(self, i, v):
                    self.indices = i
                    self.values = v
            return container(indices, values)
        return self.inference_beam(text, multinomial_sampler)

    def inference_beam(self, text_inputs, sampler_fn):
        """Beam-search decode MEL codes for a single text sequence.

        Args:
            text_inputs: 2-D integer tensor of text token ids with a batch size of exactly 1.
            sampler_fn: Callable (distribution, k) returning an object with .values and .indices,
                each holding k candidates per row of the distribution.

        Returns:
            Tensor of shape (1, n) holding the highest-scoring beam, with the leading start token and
            the final token removed and all special tokens replaced by 0.

        Raises:
            AssertionError: If the batch size is not 1.
        """
        beam_width = 16
        temperature = .8

        b, s = text_inputs.shape
        assert b == 1  # Beam search only works on batches of one.
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
        mel_seq = torch.full((b, 1), fill_value=self.MEL_START_TOKEN, dtype=torch.long, device=text_emb.device)
        # Beam scores are kept as summed log-probabilities: a running product of probabilities
        # underflows to zero in float32 long before max_mel_frames steps, which would make the
        # ranking below meaningless.
        log_probs = torch.zeros((b,), device=text_emb.device)
        # The limit is on the sequence axis (shape[1]); len() would measure the beam axis.
        while mel_seq.shape[1] < self.max_mel_frames:
            mel_emb = self.mel_embedding(mel_seq)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            if text_emb.shape[0] != mel_emb.shape[0]:
                # After the first step there are beam_width sequences; give each one the text prefix.
                text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1)
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            # Only the final position predicts the next token, so only that one is projected.
            next_logits = self.mel_head(self.final_norm(enc[:, -1]))
            # NOTE(review): the logits are multiplied by the temperature rather than divided by it,
            # so values below 1 flatten the distribution - confirm this is intended.
            topk = sampler_fn(F.softmax(temperature * next_logits, dim=-1), k=beam_width)
            log_probs = log_probs.repeat_interleave(beam_width, dim=0) + torch.log(topk.values.flatten())
            log_probs, sort_indices = torch.sort(log_probs, descending=True)
            log_probs = log_probs[:beam_width]

            # Expand every beam with each of its candidates, then keep the best beam_width overall.
            mel_seq = mel_seq.repeat_interleave(beam_width, dim=0)
            codes = topk.indices.flatten()
            mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1)
            mel_seq = mel_seq[sort_indices]
            mel_seq = mel_seq[:beam_width]
            # NOTE(review): beams that already contain a stop token keep being extended until every
            # beam contains one, so the winner can carry codes after its stop token.
            if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)):
                break

        if mel_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
        # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
        mel_seq = mel_seq[0, 1:-1].unsqueeze(0)
        mel_seq = mel_seq * (mel_seq < self.NUMBER_MEL_CODES)  # The DVAE doesn't understand BOS/EOS/PAD tokens.
        return mel_seq
2021-08-08 17:38:52 +00:00
2021-07-28 02:33:30 +00:00
@register_model
def register_gpt_tts(opt_net, opt):
    """Construct a GptTts from the optional 'kwargs' entry of the network options."""
    # A missing 'kwargs' entry falls back to an empty dict, i.e. the constructor defaults.
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return GptTts(**model_kwargs)
2021-07-28 02:33:30 +00:00
if __name__ == '__main__':
    # Smoke test: run one forward pass on random tokens and print the predicted code shape.
    gpt = GptTts()
    # The mel sequence may not be longer than the mel position-embedding table.
    mel_len = gpt.max_mel_frames
    # forward() returns four values: text loss, mel loss, predicted mel codes, formatted mel targets.
    l1, l2, i, t = gpt(torch.randint(high=24, size=(2, 60)),
                       torch.tensor([55, 58]),
                       torch.randint(high=512, size=(2, mel_len)),
                       torch.tensor([mel_len - 10, mel_len - 5]))
    print(i.shape)

    #o = gpt.inference(torch.randint(high=24, size=(2,60)))
    #print(o.shape)
2021-07-31 05:07:35 +00:00