forked from mrq/DL-Art-School
More lossy fixes
This commit is contained in:
parent
dee34f096c
commit
2fb4213a3e
|
@ -34,10 +34,10 @@ class GptTtsHf(nn.Module):
|
|||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
||||
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||
self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
||||
seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
|
@ -59,7 +59,6 @@ class GptTtsHf(nn.Module):
|
|||
|
||||
|
||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.text_embedding(text_targets)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||
|
||||
|
@ -72,7 +71,6 @@ class GptTtsHf(nn.Module):
|
|||
conds = conds + self.conditioning_embedding
|
||||
|
||||
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
||||
mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN)
|
||||
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
||||
|
||||
|
@ -82,10 +80,10 @@ class GptTtsHf(nn.Module):
|
|||
return gpt_out.attentions
|
||||
enc = gpt_out.last_hidden_state
|
||||
|
||||
text_logits = self.final_norm(enc[:, :self.max_symbols_per_phrase])
|
||||
text_logits = self.final_norm(enc[:, :self.max_symbols_per_phrase+1])
|
||||
text_logits = self.text_head(text_logits)
|
||||
text_logits = text_logits.permute(0,2,1)
|
||||
mel_logits = self.final_norm(enc[:, -self.max_mel_tokens:])
|
||||
mel_logits = self.final_norm(enc[:, -(self.max_mel_tokens+1):])
|
||||
mel_logits = self.mel_head(mel_logits)
|
||||
mel_logits = mel_logits.permute(0,2,1)
|
||||
|
||||
|
@ -109,9 +107,9 @@ class GptTtsHf(nn.Module):
|
|||
if return_attentions:
|
||||
return mel_logits
|
||||
|
||||
text_targets = F.pad(text_inputs, (0,self.max_symbols_per_phrase-text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
|
||||
text_targets = F.pad(text_inputs, (0,self.max_symbols_per_phrase-text_inputs.shape[1]+1), value=self.STOP_TEXT_TOKEN)
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
mel_targets = F.pad(mel_targets, (0,self.max_mel_tokens-mel_targets.shape[1]), value=self.STOP_MEL_TOKEN)
|
||||
mel_targets = F.pad(mel_targets, (0,self.max_mel_tokens-mel_targets.shape[1]+1), value=self.STOP_MEL_TOKEN)
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user