Add position embeddings back into unified_voice

I think this may be the solution behind the days problems.
This commit is contained in:
James Betker 2021-12-25 23:10:56 -07:00
parent 64cb4a92db
commit e959541494

View File

@ -41,7 +41,7 @@ class UnifiedGptVoice(nn.Module):
- Voice conditioned on text
"""
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_conditioning_inputs=3,
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193):
@ -56,12 +56,15 @@ class UnifiedGptVoice(nn.Module):
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
self.max_total_tokens = max_total_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
@ -145,9 +148,10 @@ class UnifiedGptVoice(nn.Module):
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
if text_first:
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
else:
@ -168,7 +172,7 @@ class UnifiedGptVoice(nn.Module):
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean()
@ -183,17 +187,18 @@ class UnifiedGptVoice(nn.Module):
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
# Randomly permute the conditioning spectrogram, to destroy any structure present.
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)