unified_voice: begin decoupling from HF GPT
I'd like to try some different (newer) transformer variants. The way to get there is softly decoupling the transformer portion of this architecture from GPT. This actually should be fairly easy.
This commit is contained in:
parent 1f6a5310b8
commit 34774f9948
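To make the goal in the commit message concrete: after this change the model asks a factory function for its transformer instead of instantiating GPT-2 inline, so a different backbone could eventually be returned from a factory with the same signature. The sketch below is purely hypothetical and not part of this commit; a stand-in backbone's forward() interface would still have to be adapted to what UnifiedGptVoice expects from GPT2Model (e.g. accepting inputs_embeds).

# Hypothetical sketch of where the decoupling is headed -- not part of this commit.
import torch.nn as nn

def build_plain_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
    # Toy stand-in backbone behind the same factory signature as build_hf_gpt_transformer.
    # num_tokens / max_seq_len / checkpointing are ignored here, and a causal mask would be
    # needed in practice to keep the model autoregressive.
    layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=heads, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=layers)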
@@ -63,6 +63,18 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
 
+def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+    gpt_config = GPT2Config(vocab_size=num_tokens,
+                            n_positions=max_seq_len,
+                            n_ctx=max_seq_len,
+                            n_embd=model_dim,
+                            n_layer=layers,
+                            n_head=heads,
+                            gradient_checkpointing=checkpointing,
+                            use_cache=not checkpointing)
+    return GPT2Model(gpt_config)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
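A rough usage sketch of the new factory (hypothetical toy sizes; assumes the Hugging Face transformers package is installed and that build_hf_gpt_transformer from the hunk above is in scope). The returned GPT2Model is driven with externally computed embeddings via inputs_embeds, which appears to be how UnifiedGptVoice feeds it since it owns its own text/mel embeddings. Note that use_cache is switched off whenever gradient checkpointing is requested, since HF GPT-2 disables the key/value cache under checkpointing anyway.

# Usage sketch with hypothetical toy dimensions; not taken from the repository.
import torch

gpt = build_hf_gpt_transformer(layers=2, model_dim=64, heads=4,
                               num_tokens=256, max_seq_len=128, checkpointing=False)
fake_embeddings = torch.randn(1, 10, 64)      # embeddings produced outside the transformer
out = gpt(inputs_embeds=fake_embeddings, return_dict=True)
print(out.last_hidden_state.shape)            # torch.Size([1, 10, 64])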
@@ -107,7 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
         self.shuffle_conditioning = shuffle_conditioning
-
+        self.layers = layers
+        self.heads = heads
         self.max_mel_tokens = max_mel_tokens
         self.max_text_tokens = max_text_tokens
         self.model_dim = model_dim
@@ -117,16 +130,8 @@ class UnifiedGptVoice(nn.Module):
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
-        self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length,
-                                     n_ctx=seq_length,
-                                     n_embd=model_dim,
-                                     n_layer=layers,
-                                     n_head=heads,
-                                     gradient_checkpointing=checkpointing,
-                                     use_cache=not checkpointing)
-        self.gpt = GPT2Model(self.gpt_config)
+        self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
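For a sense of scale, self.seq_length is a fixed overhead of 4 (presumably the start/stop tokens for the text and mel segments) plus the text, mel and conditioning budgets. With hypothetical values:

# Hypothetical budget values, purely to illustrate the seq_length formula above.
max_text_tokens = 120
max_mel_tokens = 250
max_conditioning_inputs = 1
seq_length = 4 + max_text_tokens + max_mel_tokens + max_conditioning_inputs
print(seq_length)  # 375 -> becomes n_positions/n_ctx of the underlying GPT-2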
@@ -314,7 +319,16 @@ class UnifiedGptVoice(nn.Module):
 
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
+            # TODO: Decouple gpt_config from this inference model.
+            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
+                                    n_positions=self.seq_length,
+                                    n_ctx=self.seq_length,
+                                    n_embd=self.model_dim,
+                                    n_layer=self.layers,
+                                    n_head=self.heads,
+                                    gradient_checkpointing=False,
+                                    use_cache=True)
+            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -332,7 +346,7 @@ class UnifiedGptVoice(nn.Module):
         fake_inputs[:,-1] = self.start_mel_token
 
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
+                                            max_length=self.seq_length, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
 
 
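Finally, a hedged sketch of how the **hf_generate_kwargs pass-through might be used. Model construction is omitted and the tensor shapes are illustrative assumptions, not taken from this commit; the sampling options shown are ordinary Hugging Face generate() arguments that flow straight through to the inference model, whose output length is now bounded by self.seq_length rather than the training-time config.

# Illustrative call only; `model` is assumed to be a constructed UnifiedGptVoice,
# and the input shapes below are assumptions (a conditioning spectrogram plus a
# batch of tokenized text ids).
import torch

cond_input = torch.randn(1, 80, 400)          # assumed conditioning input (batch, mel_bins, frames)
text_ids = torch.randint(0, 255, (1, 60))     # assumed tokenized text
mel_codes = model.inference_speech(cond_input, text_ids,
                                   do_sample=True, top_p=0.8, temperature=0.9)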