unified_voice: begin decoupling from HF GPT

I'd like to try some different (newer) transformer variants. The way to get
there is to softly decouple the transformer portion of this architecture
from GPT. This should actually be fairly easy.
James Betker 2022-01-07 22:51:24 -07:00
parent 1f6a5310b8
commit 34774f9948
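
For context on the diff below: the HF-specific construction (GPT2Config + GPT2Model) now lives behind a single build_hf_gpt_transformer factory, and UnifiedGptVoice only keeps the plain numbers it needs (layers, heads, seq_length). Here is a minimal sketch of the pattern outside the diff; the torch-native sibling builder is a hypothetical illustration of where this decoupling could lead, not part of this commit:

import torch.nn as nn
from transformers import GPT2Config, GPT2Model

def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
    # Same helper the diff adds: every HF GPT-2 detail is confined to this function.
    gpt_config = GPT2Config(vocab_size=num_tokens,
                            n_positions=max_seq_len,
                            n_ctx=max_seq_len,
                            n_embd=model_dim,
                            n_layer=layers,
                            n_head=heads,
                            gradient_checkpointing=checkpointing,
                            use_cache=not checkpointing)
    return GPT2Model(gpt_config)

def build_torch_native_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
    # Hypothetical sibling builder, purely illustrative: once callers depend only on a
    # factory with this signature, a different backbone can be slotted in behind it.
    # A real replacement would still need a causal mask to stay autoregressive and a
    # thin adapter matching how GPT2Model is driven (e.g. inputs_embeds=...);
    # num_tokens and max_seq_len are unused in this toy sketch.
    layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=heads, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=layers)

On the caller side, transformer construction then reduces to a single call to the factory (see the third hunk below), which is the soft decoupling the commit message refers to.
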


@@ -63,6 +63,18 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
+
+def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    gpt_config = GPT2Config(vocab_size=num_tokens,
+                            n_positions=max_seq_len,
+                            n_ctx=max_seq_len,
+                            n_embd=model_dim,
+                            n_layer=layers,
+                            n_head=heads,
+                            gradient_checkpointing=checkpointing,
+                            use_cache=not checkpointing)
+    return GPT2Model(gpt_config)
+
 
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -107,7 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
         self.shuffle_conditioning = shuffle_conditioning
+        self.layers = layers
+        self.heads = heads
         self.max_mel_tokens = max_mel_tokens
         self.max_text_tokens = max_text_tokens
         self.model_dim = model_dim
@@ -117,16 +130,8 @@ class UnifiedGptVoice(nn.Module):
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
-        self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length,
-                                     n_ctx=seq_length,
-                                     n_embd=model_dim,
-                                     n_layer=layers,
-                                     n_head=heads,
-                                     gradient_checkpointing=checkpointing,
-                                     use_cache=not checkpointing)
-        self.gpt = GPT2Model(self.gpt_config)
+        self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
@@ -314,7 +319,16 @@ class UnifiedGptVoice(nn.Module):
 
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
+            # TODO: Decouple gpt_config from this inference model.
+            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
+                                    n_positions=self.seq_length,
+                                    n_ctx=self.seq_length,
+                                    n_embd=self.model_dim,
+                                    n_layer=self.layers,
+                                    n_head=self.heads,
+                                    gradient_checkpointing=False,
+                                    use_cache=True)
+            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -332,7 +346,7 @@ class UnifiedGptVoice(nn.Module):
         fake_inputs[:,-1] = self.start_mel_token
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
+                                            max_length=self.seq_length, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
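
One note on the last hunk: max_length previously came from self.gpt_config.n_positions, which no longer exists now that the config is built inside the factory; self.seq_length carries the same number, so the cap handed to generate() is unchanged. A small worked example of that quantity, using assumed hyperparameters rather than the project's actual defaults:

# Illustrative values only, not taken from this commit.
max_text_tokens = 120
max_mel_tokens = 250
max_conditioning_inputs = 1

# The constant 4 presumably covers the start/stop tokens wrapped around the
# text and MEL streams.
seq_length = 4 + max_text_tokens + max_mel_tokens + max_conditioning_inputs
assert seq_length == 375

# Both the positional budget of the transformer (n_positions=seq_length) and
# the generation cap (max_length=seq_length) now derive from this one number.
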