Remove dedicated positioning embeddings
commit c813befd53 (parent b4ddcd7111)
@@ -32,9 +32,6 @@ class GptTtsHf(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
         self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
                                      n_positions=seq_length,
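Context for the removal above (not part of the commit): the Hugging Face GPT2Model configured with n_positions=seq_length already learns an absolute position table (wpe) spanning the whole concatenated text/conditioning/mel sequence, and it adds that table even when the model is driven through inputs_embeds, which appears to make the dedicated per-modality position embeddings redundant. A minimal sketch, assuming torch and transformers are installed; the tiny config values are placeholders, not the model's real hyperparameters:

    # Minimal sketch (not from this commit): GPT-2 applies its own learned
    # position table (wpe) of size n_positions even when fed inputs_embeds.
    import torch
    from transformers import GPT2Config, GPT2Model

    seq_length = 32                              # placeholder for the 2+text+cond+mel budget
    config = GPT2Config(vocab_size=10, n_positions=seq_length,
                        n_embd=16, n_layer=1, n_head=2)
    gpt = GPT2Model(config)

    print(gpt.wpe.weight.shape)                  # torch.Size([32, 16]) -- built-in positional table

    emb = torch.randn(1, 8, 16)                  # stand-in for cat([text_emb, conds, mel_emb], dim=1)
    out = gpt(inputs_embeds=emb, return_dict=True)
    print(out.last_hidden_state.shape)           # torch.Size([1, 8, 16])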
@@ -57,7 +54,6 @@ class GptTtsHf(nn.Module):

     def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))

         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -65,10 +61,8 @@ class GptTtsHf(nn.Module):
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding

         mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_inputs.device))

         emb = torch.cat([text_emb, conds, mel_emb], dim=1)
         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
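A toy reproduction of the conditioning flow in get_logits above, as a sketch with made-up shapes and a stand-in encoder rather than the model's real modules: each conditioning clip is encoded to one vector, the list is padded to max_conditioning_inputs by repeating the last clip, and the stacked result is concatenated with the text and mel embeddings before being passed to GPT-2 via inputs_embeds.

    # Toy sketch of the concatenation pattern (encoder and shapes are stand-ins).
    import torch
    import torch.nn as nn

    model_dim, max_conditioning_inputs = 16, 2
    encoder = nn.Linear(80, model_dim)           # stand-in for AudioMiniEncoder(80, model_dim)

    text_emb = torch.randn(1, 5, model_dim)      # embedded text tokens
    mel_emb = torch.randn(1, 7, model_dim)       # embedded mel codes
    cond_inputs = torch.randn(1, 1, 80)          # (batch, num_clips, features): one clip provided

    conds = []
    for k in range(cond_inputs.shape[1]):
        conds.append(encoder(cond_inputs[:, k])) # one vector per conditioning clip
    while len(conds) < max_conditioning_inputs:
        conds.append(conds[-1])                  # repeat the last clip to fill the budget
    conds = torch.stack(conds, dim=1)            # (1, max_conditioning_inputs, model_dim)

    emb = torch.cat([text_emb, conds, mel_emb], dim=1)
    print(emb.shape)                             # torch.Size([1, 14, 16]) -> fed as inputs_embeds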
@@ -117,7 +111,6 @@ class GptTtsHf(nn.Module):
         text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))

         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -125,7 +118,6 @@ class GptTtsHf(nn.Module):
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding

         emb = torch.cat([text_emb, conds], dim=1)
         self.inference_model.store_mel_emb(emb)