Remove dedicated positioning embeddings

James Betker 2021-12-19 09:01:31 -07:00
parent b4ddcd7111
commit c813befd53

@@ -32,9 +32,6 @@ class GptTtsHf(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
         self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
                                      n_positions=seq_length,
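
The Hugging Face GPT-2 backbone configured above already owns a learned position table (wpe, sized by n_positions=seq_length) and adds it to inputs_embeds internally, which is what makes the dedicated text/mel/conditioning position embeddings redundant. A minimal sketch of that behavior, assuming the standard transformers GPT2Config/GPT2Model API and toy dimensions in place of the repo's real values:

# Sketch with toy sizes (not the repo's values): HF GPT-2 applies its own
# learned position embeddings (wpe) even when fed inputs_embeds directly.
import torch
from transformers import GPT2Config, GPT2Model

seq_length = 16   # stands in for 2 + max_symbols_per_phrase + max_conditioning_inputs + max_mel_tokens
model_dim = 32    # toy model width
config = GPT2Config(vocab_size=100, n_positions=seq_length,
                    n_embd=model_dim, n_layer=2, n_head=2)
gpt = GPT2Model(config)

emb = torch.randn(1, seq_length, model_dim)   # concatenated text/cond/mel embeddings
out = gpt(inputs_embeds=emb, return_dict=True)
# GPT2Model computes wpe(position_ids) internally and adds it to inputs_embeds,
# so the concatenated sequence still receives position information.
print(out.last_hidden_state.shape)            # torch.Size([1, 16, 32])
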
@@ -57,7 +54,6 @@ class GptTtsHf(nn.Module):
     def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -65,10 +61,8 @@ class GptTtsHf(nn.Module):
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding
         mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_inputs.device))
         emb = torch.cat([text_emb, conds, mel_emb], dim=1)
         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
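
With both additions gone, the embedding assembly in get_logits reduces to concatenating the raw modality embeddings and letting the backbone supply positions. A paraphrased sketch of the resulting method body; the conditioning_encoder call inside the loop is an assumption based on how the encoder is constructed above, since that line falls outside the diff context:

def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
    text_emb = self.text_embedding(text_inputs)                      # no text_pos_embedding offset any more
    conds = []
    for k in range(cond_inputs.shape[1]):
        conds.append(self.conditioning_encoder(cond_inputs[:, k]))   # assumed: encode each conditioning clip
    while len(conds) < self.max_conditioning_inputs:
        conds.append(conds[-1])                                      # pad by repeating the last clip
    conds = torch.stack(conds, dim=1)                                # no conditioning_embedding offset
    mel_emb = self.gpt.get_input_embeddings()(mel_inputs)            # no mel_pos_embedding offset
    emb = torch.cat([text_emb, conds, mel_emb], dim=1)
    gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
    # ...the remainder of the method is unchanged by this commit
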
@@ -117,7 +111,6 @@ class GptTtsHf(nn.Module):
         text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -125,7 +118,6 @@ class GptTtsHf(nn.Module):
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding
         emb = torch.cat([text_emb, conds], dim=1)
         self.inference_model.store_mel_emb(emb)
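
The same simplification applies to the inference path: the text-plus-conditioning prefix handed to store_mel_emb no longer carries positional offsets, so position information now rests entirely on GPT-2's single wpe table covering the prefix plus the mel tokens generated afterwards. A small sanity-check sketch using a hypothetical helper (not part of the repo) and the seq_length formula from the constructor above:

def check_position_budget(max_symbols_per_phrase, max_conditioning_inputs,
                          max_mel_tokens, n_positions):
    # Hypothetical helper: the full sequence (the extra 2 tokens plus text,
    # conditioning inputs, and mel codes) must fit inside GPT-2's position table.
    seq_length = 2 + max_symbols_per_phrase + max_conditioning_inputs + max_mel_tokens
    assert seq_length <= n_positions, (
        f"sequence of {seq_length} tokens exceeds the GPT-2 position table ({n_positions})")
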