From d4a62986587d2e33abf78e73bb08907c12b4ad32 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 1 Jan 2022 14:25:27 -0700 Subject: [PATCH] more debugging --- codes/data/__init__.py | 2 +- codes/models/gpt_voice/gpt_asr_hf2.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 53f3c309..1ad2762e 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -89,7 +89,7 @@ def create_dataset(dataset_opt, return_collate=False): elif mode == 'grand_conjoined_voice': from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D from data.zero_pad_dict_collate import ZeroPadDictCollate as C - if opt_get(dataset_opt, ['needs_collate'], True): + if opt_get(dataset_opt, ['needs_collate'], False): collate = C() else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index d315cbae..ea4b5303 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -288,7 +288,7 @@ class GptAsrHf2(nn.Module): mel_len = 0 else: mel_emb = self.mel_encoder(mel_inputs) - assert mel_emb.shape[1] <= self.max_mel_frames + assert mel_emb.shape[1] <= self.max_mel_frames, f'{mel_emb.shape[1]} > {self.max_mel_frames}' mel_emb = mel_emb.permute(0,2,1).contiguous() mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) emb = torch.cat([mel_emb, text_emb], dim=1) @@ -303,8 +303,8 @@ class GptAsrHf2(nn.Module): return text_logits def forward(self, mel_inputs, text_inputs, return_attentions=False): - assert text_inputs.shape[1] <= self.max_symbols_per_phrase - assert text_inputs.max() <= self.number_text_tokens + assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1]) + assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max()) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token) text_emb = self.gpt.get_input_embeddings()(text_inputs) + \ @@ -317,8 +317,8 @@ class GptAsrHf2(nn.Module): return loss_text.mean(), text_logits def text_only(self, text_inputs): - assert text_inputs.shape[1] <= self.max_symbols_per_phrase - assert text_inputs.max() <= self.number_text_tokens + assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1]) + assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max()) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token) text_emb = self.gpt.get_input_embeddings()(text_inputs) + \