import torch
import torch . nn as nn
import torch . nn . functional as F
from models . gpt_voice . lucidrains_gpt import Transformer
from models . tacotron2 . taco_utils import get_mask_from_lengths
from models . tacotron2 . text import symbols
from trainer . networks import register_model
class GptTts(nn.Module):
    """Autoregressive text-to-speech model.

    A decoder-only transformer consumes a sequence of embedded text tokens
    followed by embedded discrete MEL tokens (VQVAE codebook indices) and is
    trained to predict the next token at every position. At inference time the
    MEL tokens are decoded greedily, one at a time, until a stop token is
    produced.
    """

    # Text vocabulary: the symbol set plus per-position tokens plus 2 specials.
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2

    # MEL vocabulary: 512 VQVAE codes plus three specials (START/STOP/pad).
    MEL_DICTIONARY_SIZE = 512 + 3
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2

    def __init__(self):
        super().__init__()
        model_dim = 512
        # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the
        # input mel as tokens.
        max_mel_frames = 900 * 3 // 8

        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        # Learned positional embeddings, applied separately to the text and mel
        # segments of the combined sequence.
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
        self.gpt = Transformer(dim=model_dim, depth=8,
                               seq_len=1 + self.MAX_SYMBOLS_PER_PHRASE + max_mel_frames,
                               heads=8)
        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
        """Teacher-forced training pass.

        Args:
            text_inputs: (batch, text_len) long tensor of text tokens.
            text_lengths: (batch,) true lengths of the text sequences.
                (Currently unused; kept for interface compatibility.)
            mel_targets: (batch, mel_len) long tensor of MEL tokens.
            output_lengths: (batch,) true lengths of the MEL sequences.

        Returns:
            (mean text loss, mean mel loss, predicted mel codes suitable for a
            VAE decoder, similarly formatted mel targets for validation).
        """
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
        # One combined sequence: [text tokens | mel tokens].
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = self.gpt(emb)
        # Compute logits for text and mel heads from their segments of the
        # transformer output.
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        text_logits = self.text_head(text_logits)
        mel_logits = self.mel_head(mel_logits)

        # Next-token prediction loss: logits at position i predict token i+1.
        # The last element of the logits is unneeded because the input to the
        # transformer contains a <EOS> token for both text and mel.
        text_targets = text_inputs[:, 1:]
        text_logits = text_logits.permute(0, 2, 1)[:, :, :-1]
        loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
        mel_targets = mel_targets[:, 1:]
        mel_logits = mel_logits.permute(0, 2, 1)[:, :, :-1]
        loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
        # NOTE(review): the per-element losses are averaged below without
        # masking padded positions — confirm this is intended.

        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
        # Zero out codes in the padded region.
        mel_pad_mask = ~get_mask_from_lengths(output_lengths - 1, mel_targets.shape[1])
        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
        # Strip off <EOS> token too (or padding). The important part is that
        # the output sequence length is identical to the VAE input.
        mel_codes = mel_codes[:, :-1]
        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE - 3  # The VAE doesn't know about START/STOP/PAD
        mel_codes = mel_codes * extra_mask

        # This class also returns the mel_targets for validation purposes.
        # Format those the same way.
        mel_targets = mel_targets[:, :-1]
        mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE - 3)
        return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets

    def inference(self, text_inputs):
        """Greedily decode MEL tokens for a single (batch=1) text sequence.

        Returns a plain list of int MEL codes, sanitized so every value is a
        valid VAE codebook index (specials mapped to 0) and with the trailing
        stop/overflow token stripped.
        """
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

        mel_seq = [self.MEL_START_TOKEN, 0]
        while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
            mel_seq.append(0)  # placeholder slot for the next prediction
            mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
            mel_logits = self.mel_head(mel_logits)
            mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
            # BUGFIX: mel_codes has shape (1, seq); the original assigned the
            # whole row tensor into mel_seq[-1], which made the scalar loop
            # condition above raise on the next iteration. Take the code at the
            # final position as a Python int instead.
            mel_seq[-1] = mel_codes[0, -1].item()

        if len(mel_seq) >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Prevent sending invalid tokens (START/STOP/PAD) to the VAE; was a
        # hard-coded 512, now consistent with forward()'s masking.
        mel_seq = [s if s < self.MEL_DICTIONARY_SIZE - 3 else 0 for s in mel_seq]
        return mel_seq[:-1]
@register_model
def register_gpt_tts(opt_net, opt):
    """Trainer registry hook; builds a GptTts (both opt dicts are unused)."""
    return GptTts()
if __name__ == '__main__':
    # Smoke test: one teacher-forced forward pass on random token batches.
    gpt = GptTts()
    # BUGFIX: forward() returns FOUR values (text loss, mel loss, mel codes,
    # formatted mel targets); the original unpacked only three, which raised
    # ValueError at runtime.
    l1, l2, i, t = gpt(torch.randint(high=24, size=(2, 60)),
                       torch.tensor([55, 58]),
                       torch.randint(high=512, size=(2, 310)),
                       torch.tensor([300, 305]))
    print(i.shape)

    #o = gpt.inference(torch.randint(high=24, size=(1,60)))
    #print(o.shape)