forked from mrq/DL-Art-School
better bounds
This commit is contained in:
parent fe9ea4e01a
commit dc535b5358
@@ -288,6 +288,7 @@ class GptAsrHf2(nn.Module):
             mel_len = 0
         else:
             mel_emb = self.mel_encoder(mel_inputs)
+            assert mel_emb.shape[1] <= self.max_mel_frames
             mel_emb = mel_emb.permute(0,2,1).contiguous()
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
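The added assert bounds the mel sequence before the positional-embedding lookup on the next line. As a minimal sketch (not code from this repo), assuming mel_pos_embedding is backed by an nn.Embedding sized to max_mel_frames, an over-long mel sequence would otherwise fail inside the lookup with a less obvious error:

import torch
import torch.nn as nn

# Hypothetical stand-in for the model's positional table, sized to max_mel_frames.
max_mel_frames = 4
mel_pos_embedding = nn.Embedding(max_mel_frames, 8)

mel_pos_embedding(torch.arange(4))       # within bounds: returns a (4, 8) tensor

try:
    mel_pos_embedding(torch.arange(5))   # one frame too many
except (IndexError, RuntimeError):
    print("position index out of range; the new assert reports this earlier and more clearly")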
@@ -302,6 +303,9 @@ class GptAsrHf2(nn.Module):
         return text_logits

     def forward(self, mel_inputs, text_inputs, return_attentions=False):
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
+        assert text_inputs.max() <= self.number_text_tokens
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
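The two asserts added to forward bound both axes of text_inputs: sequence length against the text positional table and token ids against the vocabulary size. A standalone sketch of the same checks, where check_text_inputs and the sizes are illustrative only, not helpers or values from this repo:

import torch

def check_text_inputs(text_inputs, max_symbols_per_phrase, number_text_tokens):
    # Sequence length must fit the learned text positional embeddings.
    assert text_inputs.shape[1] <= max_symbols_per_phrase, \
        f"got {text_inputs.shape[1]} symbols, max is {max_symbols_per_phrase}"
    # Every token id must fit the text embedding vocabulary.
    assert text_inputs.max() <= number_text_tokens, \
        f"got token id {int(text_inputs.max())}, max is {number_text_tokens}"

# Example with assumed sizes.
check_text_inputs(torch.randint(0, 256, (2, 80)), max_symbols_per_phrase=100, number_text_tokens=256)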
@@ -313,9 +317,9 @@ class GptAsrHf2(nn.Module):
         return loss_text.mean(), text_logits

     def text_only(self, text_inputs):
-        if text_inputs.shape[1] > self.max_symbols_per_phrase:
-            print(f"Embedding error, provided text_inputs with shape {text_inputs.shape}, but max is {self.max_symbols_per_phrase}. Automatically correcting by truncating symbols.")
-            text_inputs = text_inputs[:, :self.max_symbols_per_phrase]
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
+        assert text_inputs.max() <= self.number_text_tokens
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
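In text_only, the old path printed a warning and silently truncated over-long inputs; the new asserts reject them instead, matching the checks added to forward. A small sketch of the behavioural difference, with the sizes assumed for illustration:

import torch

max_symbols_per_phrase = 100
text_inputs = torch.randint(0, 256, (2, 120))           # 120 symbols, 20 too many

# Old behaviour: warn and clip, so training continues on silently shortened text.
truncated = text_inputs[:, :max_symbols_per_phrase]     # shape (2, 100)

# New behaviour: fail fast so the over-long batch is fixed upstream.
try:
    assert text_inputs.shape[1] <= max_symbols_per_phrase
except AssertionError:
    print("over-long batch rejected instead of truncated")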