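"""
GptTts: a GPT-style text-to-speech model. Text symbol tokens and discrete VQVAE mel
codes are embedded into a single transformer sequence; forward() trains with teacher
forcing over both text and mel, and inference() greedily decodes mel codes from text.
"""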
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.gpt_voice.lucidrains_gpt import Transformer
from models.gpt_voice.min_gpt import GPT, GPTConfig
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from trainer.networks import register_model


class GptTts(nn.Module):
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
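    # The extra MAX_SYMBOLS_PER_PHRASE entries and the +2 in the text vocabulary appear
    # to reserve room for positional/control tokens; the exact allocation is an
    # assumption, since it is not spelled out in this file.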
    MEL_DICTIONARY_SIZE = 1024 + 3  # 1024 VQVAE codes plus the START/STOP/PAD tokens below.
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2

    def __init__(self):
        super().__init__()
        model_dim = 512
        max_mel_frames = 900 * 3 // 8  # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
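        # For reference: 900 * 3 // 8 == 337, so mel code sequences are capped at 337 positions.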
        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        self.gpt = GPT(GPTConfig(1 + self.MAX_SYMBOLS_PER_PHRASE + max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
        #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
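        # The GPT block size budgets 1 extra position (presumably the mel <start> token)
        # on top of the maximum text and mel lengths. do_pos_emb=False because positional
        # embeddings are added manually via text_pos_embedding/mel_pos_embedding above.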
        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = self.gpt(emb)

        # Compute logits for the text and mel heads.
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        text_logits = self.text_head(text_logits)
        mel_logits = self.mel_head(mel_logits)

        # Compute the losses.
        text_targets = text_inputs[:, 1:]
        text_logits = text_logits.permute(0, 2, 1)[:, :, :-1]  # The last logit is unneeded because the input to the transformer contains an <EOS> token for both text and mel.
        loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
        mel_targets = mel_targets[:, 1:]
        mel_logits = mel_logits.permute(0, 2, 1)[:, :, :-1]
        loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
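        # F.cross_entropy expects class logits in dim 1, hence the permute to
        # (batch, vocab, seq) above; the targets are simply the inputs shifted left by one.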

        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
        mel_pad_mask = ~get_mask_from_lengths(output_lengths - 1, mel_targets.shape[1])
        mel_codes = mel_codes.masked_fill(mel_pad_mask, 0)  # Zero out codes in the padded region.
        mel_codes = mel_codes[:, :-1]  # Strip off the <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE - 3  # The VAE doesn't know about START/STOP/PAD.
        mel_codes = mel_codes * extra_mask

        # This method also returns the mel_targets for validation purposes. Format them the same way.
        mel_targets = mel_targets[:, :-1]
        mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE - 3)
        return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
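
    # forward() teacher-forces the mel codes; inference() below instead decodes greedily
    # (argmax at each step) until every sequence in the batch has emitted MEL_STOP_TOKEN
    # or the frame limit is reached.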
    def inference(self, text_inputs):
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

        mel_seq = torch.full((text_emb.shape[0], 1), fill_value=self.MEL_START_TOKEN, dtype=torch.long, device=text_emb.device)
        stop_encountered = torch.zeros((text_emb.shape[0],), dtype=torch.bool, device=text_emb.device)
        # The generated sequence length is tracked along dim 1 (dim 0 is the batch).
        while not torch.all(stop_encountered) and mel_seq.shape[1] < self.max_mel_frames:
            mel_emb = self.mel_embedding(mel_seq)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
            mel_logits = self.mel_head(mel_logits)
            mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
            mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:, -1] == self.MEL_STOP_TOKEN)

        if mel_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Prevent sending invalid tokens (START/STOP/PAD) to the VAE, and pad the output
        # length up to a multiple of 3, which the DVAE requires.
        mel_seq = mel_seq * (mel_seq < self.MEL_DICTIONARY_SIZE - 3)
        padding_needed = (3 - mel_seq.shape[1] % 3) % 3
        mel_seq = F.pad(mel_seq, (0, padding_needed))
        return mel_seq


@register_model
def register_gpt_tts(opt_net, opt):
    return GptTts()


if __name__ == '__main__':
    gpt = GptTts()
    l1, l2, i = gpt(torch.randint(high=24, size=(2, 60)),
                    torch.tensor([55, 58]),
                    torch.randint(high=512, size=(2, 310)),
                    torch.tensor([300, 305]))
    print(i.shape)
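    # i holds the recovered mel codes; it is two tokens shorter than mel_targets (one
    # from the shift-by-one targets, one from stripping <EOS>), so this should print
    # torch.Size([2, 308]).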
    #o = gpt.inference(torch.randint(high=24, size=(2, 60)))
    #print(o.shape)