From 34774f9948dbcb48c34eca43ad8ed759866119be Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 7 Jan 2022 22:51:24 -0700 Subject: [PATCH] unified_voice: begin decoupling from HF GPT I'd like to try some different (newer) transformer variants. The way to get there is softly decoupling the transformer portion of this architecture from GPT. This actually should be fairly easy. --- codes/models/gpt_voice/unified_voice.py | 40 +++++++++++++++++-------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index 8b93ef87..6e5548a4 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -63,6 +63,18 @@ def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) +def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing): + gpt_config = GPT2Config(vocab_size=num_tokens, + n_positions=max_seq_len, + n_ctx=max_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) + return GPT2Model(gpt_config) + + class UnifiedGptVoice(nn.Module): """ Derived from GptTtsHf, but offers multiple modes of autoregressive operation: @@ -107,7 +119,8 @@ class UnifiedGptVoice(nn.Module): self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token self.shuffle_conditioning = shuffle_conditioning - + self.layers = layers + self.heads = heads self.max_mel_tokens = max_mel_tokens self.max_text_tokens = max_text_tokens self.model_dim = model_dim @@ -117,16 +130,8 @@ class UnifiedGptVoice(nn.Module): self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim) self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) - seq_length = 
4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs - self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.gpt = GPT2Model(self.gpt_config) + self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs + self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) @@ -314,7 +319,16 @@ class UnifiedGptVoice(nn.Module): def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) + # TODO: Decouple gpt_config from this inference model. 
+ gpt_config = GPT2Config(vocab_size=self.number_mel_codes, + n_positions=self.seq_length, + n_ctx=self.seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True) + self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) @@ -332,7 +346,7 @@ class UnifiedGptVoice(nn.Module): fake_inputs[:,-1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=self.gpt_config.n_positions, **hf_generate_kwargs) + max_length=self.seq_length, **hf_generate_kwargs) return gen[:, fake_inputs.shape[1]:]