From ec456b67336024111ca6d632c34555b3961a9211 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 9 Jan 2022 22:34:30 -0700 Subject: [PATCH] Revert unified_voice back to beginning I'll be doing my work within unified_voice2 --- codes/models/gpt_voice/unified_voice.py | 40 ++++++++++++++----------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index 1689d13e..8b93ef87 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -8,7 +8,6 @@ from transformers import GPT2Model, GPT2Config from models.arch_util import AttentionBlock from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel from models.gpt_voice.gpt_asr_hf2 import ResBlock -from models.gpt_voice.transformer_builders import build_hf_gpt_transformer from models.tacotron2.text import symbols from trainer.networks import register_model from utils.util import opt_get @@ -60,6 +59,10 @@ class MelEncoder(nn.Module): return x.permute(0,2,1) +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + class UnifiedGptVoice(nn.Module): """ Derived from GptTtsHf, but offers multiple modes of autoregressive operation: @@ -104,8 +107,7 @@ class UnifiedGptVoice(nn.Module): self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token self.shuffle_conditioning = shuffle_conditioning - self.layers = layers - self.heads = heads + self.max_mel_tokens = max_mel_tokens self.max_text_tokens = max_text_tokens self.model_dim = model_dim @@ -115,14 +117,25 @@ class UnifiedGptVoice(nn.Module): self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim) self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) - self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs - self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing) + seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs + self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) + self.gpt = GPT2Model(self.gpt_config) if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) - self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) else: self.mel_solo_embedding = 0 self.text_solo_embedding = 0 + # Override the built in positional embeddings + del self.gpt.wpe + self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) if not use_mel_codes_as_input: self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1) @@ -301,16 +314,7 @@ class UnifiedGptVoice(nn.Module): def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): if not hasattr(self, 'inference_model'): - # TODO: Decouple gpt_config from this inference model. - gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, - n_positions=self.seq_length, - n_ctx=self.seq_length, - n_embd=self.model_dim, - n_layer=self.layers, - n_head=self.heads, - gradient_checkpointing=False, - use_cache=True) - self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) + self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) @@ -328,7 +332,7 @@ class UnifiedGptVoice(nn.Module): fake_inputs[:,-1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=self.seq_length, **hf_generate_kwargs) + max_length=self.gpt_config.n_positions, **hf_generate_kwargs) return gen[:, fake_inputs.shape[1]:]