diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 54bf6ffe..d59543b5 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -1,16 +1,10 @@ -import random -from time import time - import torch import torch.nn as nn import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from transformers import GPT2Model, GPT2Config from models.arch_util import AttentionBlock from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel -from models.gpt_voice.mini_encoder import AudioMiniEncoder from models.tacotron2.text import symbols from trainer.networks import register_model from utils.util import opt_get @@ -39,8 +33,8 @@ class ConditioningEncoder(nn.Module): class GptTtsHf(nn.Module): - NUMBER_TEXT_TOKENS = 10000 # The number of tokens produced by our bespoke BPE tokenizer. - START_TEXT_TOKEN = 9999 + NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer. + START_TEXT_TOKEN = 255 STOP_TEXT_TOKEN = 0 NUMBER_MEL_CODES = 8194 START_MEL_TOKEN = 8192 diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index 558a8d58..ec447a3a 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -41,17 +41,18 @@ class UnifiedGptVoice(nn.Module): - Voice conditioned on text """ - NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer. - START_TEXT_TOKEN = 255 - STOP_TEXT_TOKEN = 0 - NUMBER_MEL_CODES = 8194 - START_MEL_TOKEN = 8192 - STOP_MEL_TOKEN = 8193 - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_conditioning_inputs=3, - checkpointing=True, mel_length_compression=1024, max_conditioning_length=60): + checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256, + start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, + stop_mel_token=8193): super().__init__() + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token self.max_mel_tokens = max_mel_tokens self.max_symbols_per_phrase = max_symbols_per_phrase @@ -59,23 +60,22 @@ class UnifiedGptVoice(nn.Module): self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) + self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens - self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) + self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) self.gpt = GPT2Model(self.gpt_config) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES) + self.text_head = nn.Linear(model_dim, self.number_text_tokens) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.max_conditioning_length = max_conditioning_length - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1,0), value=start_token) tar = F.pad(input, (0,1), value=stop_token) @@ -92,7 +92,7 @@ class UnifiedGptVoice(nn.Module): for b in range(len(mel_lengths)): actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. if actual_end < mel_input_tokens.shape[-1]: - mel_input_tokens[b, actual_end:] = self.STOP_MEL_TOKEN + mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens def randomly_permute_conditioning_input(self, speech_conditioning_input): @@ -118,11 +118,12 @@ class UnifiedGptVoice(nn.Module): return gpt_out.attentions enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input - first_logits = self.final_norm(enc[:, :first_inputs.shape[1]]) + enc = self.final_norm(enc) + first_logits = enc[:, :first_inputs.shape[1]] first_logits = first_head(first_logits) first_logits = first_logits.permute(0,2,1) if second_inputs is not None: - second_logits = self.final_norm(enc[:, -second_inputs.shape[1]:]) + second_logits = enc[:, -second_inputs.shape[1]:] second_logits = second_head(second_logits) second_logits = second_logits.permute(0,2,1) return first_logits, second_logits @@ -143,9 +144,9 @@ class UnifiedGptVoice(nn.Module): speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) - mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN) + mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) mel_emb = self.gpt.get_input_embeddings()(mel_inputs) if text_first: text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) @@ -166,7 +167,7 @@ class UnifiedGptVoice(nn.Module): speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head) loss_text = F.cross_entropy(text_logits, text_targets.long()) @@ -180,7 +181,7 @@ class UnifiedGptVoice(nn.Module): speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN) + mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) mel_emb = self.gpt.get_input_embeddings()(mel_inputs) mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) @@ -190,8 +191,8 @@ class UnifiedGptVoice(nn.Module): if not hasattr(self, 'inference_model'): self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head) - text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) + text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.stop_text_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) # Randomly permute the conditioning spectrogram, to destroy any structure present. @@ -202,9 +203,9 @@ class UnifiedGptVoice(nn.Module): self.inference_model.store_mel_emb(emb) fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.START_MEL_TOKEN + fake_inputs[:,-1] = self.start_mel_token - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN, + gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs) return gen[:, fake_inputs.shape[1]:] @@ -216,7 +217,7 @@ def register_unified_gpt_voice(opt_net, opt): if __name__ == '__main__': gpt = UnifiedGptVoice(model_dim=256, heads=4) - l = gpt(torch.randn(2, 120, 800), + l = gpt(torch.randn(2, 80, 800), torch.randint(high=len(symbols), size=(2,80)), torch.randint(high=8192, size=(2,250)), torch.tensor([150*256,195*256])) diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py index 6f198e8c..4957a440 100644 --- a/codes/scripts/audio/gen/use_gpt_tts.py +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -112,7 +112,7 @@ if __name__ == '__main__': num_return_sequences=args.num_samples, length_penalty=1, early_stopping=True) # Delete the GPT TTS model to free up GPU memory - stop_token = gpt.STOP_MEL_TOKEN + stop_token = gpt.stop_mel_token del gpt print("Loading DVAE..")