Make tokenization configs more configurable

This commit is contained in:
James Betker 2021-12-25 12:17:50 -07:00
parent 52410fd9d9
commit ab9cafa572
3 changed files with 37 additions and 42 deletions

View File

@ -1,16 +1,10 @@
import random
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from transformers import GPT2Model, GPT2Config
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
@ -39,8 +33,8 @@ class ConditioningEncoder(nn.Module):
class GptTtsHf(nn.Module):
NUMBER_TEXT_TOKENS = 10000 # The number of tokens produced by our bespoke BPE tokenizer.
START_TEXT_TOKEN = 9999
NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer.
START_TEXT_TOKEN = 255
STOP_TEXT_TOKEN = 0
NUMBER_MEL_CODES = 8194
START_MEL_TOKEN = 8192

View File

@ -41,17 +41,18 @@ class UnifiedGptVoice(nn.Module):
- Voice conditioned on text
"""
NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer.
START_TEXT_TOKEN = 255
STOP_TEXT_TOKEN = 0
NUMBER_MEL_CODES = 8194
START_MEL_TOKEN = 8192
STOP_MEL_TOKEN = 8193
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
@ -59,9 +60,9 @@ class UnifiedGptVoice(nn.Module):
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
@ -71,11 +72,10 @@ class UnifiedGptVoice(nn.Module):
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.max_conditioning_length = max_conditioning_length
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
@ -92,7 +92,7 @@ class UnifiedGptVoice(nn.Module):
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.STOP_MEL_TOKEN
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def randomly_permute_conditioning_input(self, speech_conditioning_input):
@ -118,11 +118,12 @@ class UnifiedGptVoice(nn.Module):
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
first_logits = self.final_norm(enc[:, :first_inputs.shape[1]])
enc = self.final_norm(enc)
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0,2,1)
if second_inputs is not None:
second_logits = self.final_norm(enc[:, -second_inputs.shape[1]:])
second_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0,2,1)
return first_logits, second_logits
@ -143,9 +144,9 @@ class UnifiedGptVoice(nn.Module):
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
if text_first:
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
@ -166,7 +167,7 @@ class UnifiedGptVoice(nn.Module):
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
loss_text = F.cross_entropy(text_logits, text_targets.long())
@ -180,7 +181,7 @@ class UnifiedGptVoice(nn.Module):
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
@ -190,8 +191,8 @@ class UnifiedGptVoice(nn.Module):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
# Randomly permute the conditioning spectrogram, to destroy any structure present.
@ -202,9 +203,9 @@ class UnifiedGptVoice(nn.Module):
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
fake_inputs[:,-1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@ -216,7 +217,7 @@ def register_unified_gpt_voice(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedGptVoice(model_dim=256, heads=4)
l = gpt(torch.randn(2, 120, 800),
l = gpt(torch.randn(2, 80, 800),
torch.randint(high=len(symbols), size=(2,80)),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))

View File

@ -112,7 +112,7 @@ if __name__ == '__main__':
num_return_sequences=args.num_samples, length_penalty=1, early_stopping=True)
# Delete the GPT TTS model to free up GPU memory
stop_token = gpt.STOP_MEL_TOKEN
stop_token = gpt.stop_mel_token
del gpt
print("Loading DVAE..")