forked from mrq/DL-Art-School
unified_voice: introduce paired embeddings
This commit is contained in:
parent
6996dfd9d5
commit
a698d3f525
|
@ -68,8 +68,10 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_solo_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
self.text_pos_paired_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
|
self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
|
self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
|
@ -90,7 +92,8 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.max_conditioning_length = max_conditioning_length
|
self.max_conditioning_length = max_conditioning_length
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]:
|
for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
|
||||||
|
self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
|
||||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
@ -168,10 +171,10 @@ class UnifiedGptVoice(nn.Module):
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
if text_first:
|
if text_first:
|
||||||
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
||||||
else:
|
else:
|
||||||
|
@ -194,7 +197,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
|
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
return loss_text.mean()
|
return loss_text.mean()
|
||||||
|
@ -211,18 +214,17 @@ class UnifiedGptVoice(nn.Module):
|
||||||
|
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
||||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
return loss_mel.mean()
|
return loss_mel.mean()
|
||||||
|
|
||||||
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head)
|
||||||
|
|
||||||
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.stop_text_token)
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
|
|
||||||
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
|
@ -235,7 +237,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
fake_inputs[:,-1] = self.start_mel_token
|
fake_inputs[:,-1] = self.start_mel_token
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||||
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
|
max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
|
||||||
return gen[:, fake_inputs.shape[1]:]
|
return gen[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user