GptTtsHf: Make the input/target placement easier to reason about
This commit is contained in:
parent
2fb4213a3e
commit
9b9f7ea61b
|
@ -52,15 +52,14 @@ class GptTtsHf(nn.Module):
|
||||||
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
|
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
|
||||||
|
|
||||||
|
|
||||||
def get_logits(self, text_inputs, cond_inputs, mel_targets, get_attns=False):
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
assert text_inputs.shape[1] <= self.max_symbols_per_phrase
|
inp = F.pad(input, (1,0), value=start_token)
|
||||||
assert cond_inputs.shape[1] <= self.max_conditioning_inputs
|
tar = F.pad(input, (0,1), value=stop_token)
|
||||||
assert mel_targets.shape[1] <= self.max_mel_tokens
|
return inp, tar
|
||||||
|
|
||||||
|
def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
|
||||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
text_emb = self.text_embedding(text_inputs)
|
||||||
text_emb = self.text_embedding(text_targets)
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
|
||||||
|
|
||||||
conds = []
|
conds = []
|
||||||
for k in range(cond_inputs.shape[1]):
|
for k in range(cond_inputs.shape[1]):
|
||||||
|
@ -70,9 +69,8 @@ class GptTtsHf(nn.Module):
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
conds = conds + self.conditioning_embedding
|
conds = conds + self.conditioning_embedding
|
||||||
|
|
||||||
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_inputs.device))
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
|
||||||
|
|
||||||
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
||||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
|
@ -80,10 +78,10 @@ class GptTtsHf(nn.Module):
|
||||||
return gpt_out.attentions
|
return gpt_out.attentions
|
||||||
enc = gpt_out.last_hidden_state
|
enc = gpt_out.last_hidden_state
|
||||||
|
|
||||||
text_logits = self.final_norm(enc[:, :self.max_symbols_per_phrase+1])
|
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
|
||||||
text_logits = self.text_head(text_logits)
|
text_logits = self.text_head(text_logits)
|
||||||
text_logits = text_logits.permute(0,2,1)
|
text_logits = text_logits.permute(0,2,1)
|
||||||
mel_logits = self.final_norm(enc[:, -(self.max_mel_tokens+1):])
|
mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:])
|
||||||
mel_logits = self.mel_head(mel_logits)
|
mel_logits = self.mel_head(mel_logits)
|
||||||
mel_logits = mel_logits.permute(0,2,1)
|
mel_logits = mel_logits.permute(0,2,1)
|
||||||
|
|
||||||
|
@ -103,13 +101,12 @@ class GptTtsHf(nn.Module):
|
||||||
if mel_lengths[b] < mel_targets.shape[-1]:
|
if mel_lengths[b] < mel_targets.shape[-1]:
|
||||||
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
||||||
|
|
||||||
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||||
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
|
||||||
|
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_inputs, get_attns=return_attentions)
|
||||||
if return_attentions:
|
if return_attentions:
|
||||||
return mel_logits
|
return mel_logits
|
||||||
|
|
||||||
text_targets = F.pad(text_inputs, (0,self.max_symbols_per_phrase-text_inputs.shape[1]+1), value=self.STOP_TEXT_TOKEN)
|
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
mel_targets = F.pad(mel_targets, (0,self.max_mel_tokens-mel_targets.shape[1]+1), value=self.STOP_MEL_TOKEN)
|
|
||||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||||
|
|
||||||
|
@ -117,10 +114,9 @@ class GptTtsHf(nn.Module):
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||||
|
|
||||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||||
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
text_emb = self.text_embedding(text_inputs)
|
||||||
text_emb = self.text_embedding(text_targets)
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
|
||||||
|
|
||||||
conds = []
|
conds = []
|
||||||
for k in range(cond_inputs.shape[1]):
|
for k in range(cond_inputs.shape[1]):
|
||||||
|
@ -133,7 +129,7 @@ class GptTtsHf(nn.Module):
|
||||||
emb = torch.cat([text_emb, conds], dim=1)
|
emb = torch.cat([text_emb, conds], dim=1)
|
||||||
self.inference_model.store_mel_emb(emb)
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
|
||||||
fake_inputs = torch.full((text_inputs.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user