auto-fix text_inputs too big
This commit is contained in:
parent
35abefd038
commit
fe9ea4e01a
|
@ -313,6 +313,9 @@ class GptAsrHf2(nn.Module):
|
|||
return loss_text.mean(), text_logits
|
||||
|
||||
def text_only(self, text_inputs):
|
||||
if text_inputs.shape[1] > self.max_symbols_per_phrase:
|
||||
print(f"Embedding error, provided text_inputs with shape {text_inputs.shape}, but max is {self.max_symbols_per_phrase}. Automatically correcting by truncating symbols.")
|
||||
text_inputs = text_inputs[:, :self.max_symbols_per_phrase]
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
|
||||
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
|
||||
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
|
||||
|
|
Loading…
Reference in New Issue
Block a user