From fe9ea4e01a21875a04de3266c0b14991c0c72f16 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 1 Jan 2022 13:25:47 -0700 Subject: [PATCH] auto-fix text_inputs too big --- codes/models/gpt_voice/gpt_asr_hf2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index 461f03b8..69ff2657 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -313,6 +313,9 @@ class GptAsrHf2(nn.Module): return loss_text.mean(), text_logits def text_only(self, text_inputs): + if text_inputs.shape[1] > self.max_symbols_per_phrase: + print(f"Embedding error, provided text_inputs with shape {text_inputs.shape}, but max is {self.max_symbols_per_phrase}. Automatically correcting by truncating symbols.") + text_inputs = text_inputs[:, :self.max_symbols_per_phrase] text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token) text_emb = self.gpt.get_input_embeddings()(text_inputs) + \ self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \