From dc535b5358fb829f4b575e07d068391e5eca1eec Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 1 Jan 2022 14:05:22 -0700
Subject: [PATCH] better bounds

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 69ff2657..d315cbae 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -288,6 +288,7 @@ class GptAsrHf2(nn.Module):
             mel_len = 0
         else:
             mel_emb = self.mel_encoder(mel_inputs)
+            assert mel_emb.shape[1] <= self.max_mel_frames
             mel_emb = mel_emb.permute(0,2,1).contiguous()
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
@@ -302,6 +303,9 @@ class GptAsrHf2(nn.Module):
         return text_logits
 
     def forward(self, mel_inputs, text_inputs, return_attentions=False):
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
+        assert text_inputs.max() <= self.number_text_tokens
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
@@ -313,9 +317,9 @@ class GptAsrHf2(nn.Module):
         return loss_text.mean(), text_logits
 
     def text_only(self, text_inputs):
-        if text_inputs.shape[1] > self.max_symbols_per_phrase:
-            print(f"Embedding error, provided text_inputs with shape {text_inputs.shape}, but max is {self.max_symbols_per_phrase}. Automatically correcting by truncating symbols.")
-            text_inputs = text_inputs[:, :self.max_symbols_per_phrase]
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
+        assert text_inputs.max() <= self.number_text_tokens
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
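
Note: because this patch replaces the old silent truncation in text_only() with hard asserts, callers are now responsible for keeping inputs inside the model's bounds before invoking it. Below is a minimal caller-side sketch of that guard. The helper name check_text_inputs is hypothetical and the default values (200 symbols, 512 tokens) are illustrative assumptions, not values from the repository; only the attribute names max_symbols_per_phrase and number_text_tokens come from the patched model.

import torch

def check_text_inputs(text_inputs, max_symbols_per_phrase=200, number_text_tokens=512):
    # Truncate explicitly - the behavior the model itself used to apply
    # before this patch turned it into an assertion failure.
    if text_inputs.shape[1] > max_symbols_per_phrase:
        text_inputs = text_inputs[:, :max_symbols_per_phrase]
    # Token ids past the vocabulary bound would now trip the model's
    # assert, so fail loudly here with a useful message instead.
    if text_inputs.max() > number_text_tokens:
        raise ValueError(f"token id {int(text_inputs.max())} exceeds bound {number_text_tokens}")
    return text_inputs

# Usage sketch: a batch of 2 phrases, 300 symbols each, gets truncated to 200.
tokens = torch.randint(0, 512, (2, 300))
tokens = check_text_inputs(tokens)
assert tokens.shape[1] <= 200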