diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index 8b93ef87..6e5548a4 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -63,6 +63,18 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
 
+def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    gpt_config = GPT2Config(vocab_size=num_tokens,
+                            n_positions=max_seq_len,
+                            n_ctx=max_seq_len,
+                            n_embd=model_dim,
+                            n_layer=layers,
+                            n_head=heads,
+                            gradient_checkpointing=checkpointing,
+                            use_cache=not checkpointing)
+    return GPT2Model(gpt_config)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -107,7 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
         self.shuffle_conditioning = shuffle_conditioning
-
+        self.layers = layers
+        self.heads = heads
         self.max_mel_tokens = max_mel_tokens
         self.max_text_tokens = max_text_tokens
         self.model_dim = model_dim
@@ -117,16 +130,8 @@ class UnifiedGptVoice(nn.Module):
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
-        self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length,
-                                     n_ctx=seq_length,
-                                     n_embd=model_dim,
-                                     n_layer=layers,
-                                     n_head=heads,
-                                     gradient_checkpointing=checkpointing,
-                                     use_cache=not checkpointing)
-        self.gpt = GPT2Model(self.gpt_config)
+        self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
@@ -314,7 +319,16 @@ class UnifiedGptVoice(nn.Module):
 
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
+            # TODO: Decouple gpt_config from this inference model.
+            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
+                                    n_positions=self.seq_length,
+                                    n_ctx=self.seq_length,
+                                    n_embd=self.model_dim,
+                                    n_layer=self.layers,
+                                    n_head=self.heads,
+                                    gradient_checkpointing=False,
+                                    use_cache=True)
+            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -332,7 +346,7 @@ class UnifiedGptVoice(nn.Module):
         fake_inputs[:,-1] = self.start_mel_token
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
+                                            max_length=self.seq_length, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
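Usage note (not part of the diff): a minimal sketch of how the extracted build_hf_gpt_transformer helper can be exercised on its own, assuming only the GPT2Config/GPT2Model imports from HuggingFace transformers that unified_voice.py already uses. The layer/head/dimension values below are placeholders chosen for illustration, not the model's real hyperparameters.

# Standalone sketch of the helper introduced above; toy sizes, illustrative only.
import torch
from transformers import GPT2Config, GPT2Model


def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
    # Same construction as the helper added in the diff.
    gpt_config = GPT2Config(vocab_size=num_tokens,
                            n_positions=max_seq_len,
                            n_ctx=max_seq_len,
                            n_embd=model_dim,
                            n_layer=layers,
                            n_head=heads,
                            gradient_checkpointing=checkpointing,
                            use_cache=not checkpointing)
    return GPT2Model(gpt_config)


# Tiny configuration so the example runs quickly.
gpt = build_hf_gpt_transformer(layers=2, model_dim=64, heads=2,
                               num_tokens=256, max_seq_len=128, checkpointing=False)

# UnifiedGptVoice drives the transformer with pre-computed embeddings rather
# than token ids, so the sketch does the same via inputs_embeds.
emb = torch.randn(1, 10, 64)
out = gpt(inputs_embeds=emb, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 10, 64])

The inference path now rebuilds an equivalent GPT2Config on demand (with checkpointing off and the KV cache on) and clamps generation at self.seq_length, so removing the stored self.gpt_config does not change the effective generation limit.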