forked from mrq/DL-Art-School
more debugging
parent d8111e0477
commit d4a6298658
@@ -89,7 +89,7 @@ def create_dataset(dataset_opt, return_collate=False):
     elif mode == 'grand_conjoined_voice':
         from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
         from data.zero_pad_dict_collate import ZeroPadDictCollate as C
-        if opt_get(dataset_opt, ['needs_collate'], True):
+        if opt_get(dataset_opt, ['needs_collate'], False):
             collate = C()
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
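The only functional change in this hunk is the default passed to opt_get: a config that omits needs_collate entirely now falls through to PyTorch's default collation instead of getting a ZeroPadDictCollate. A minimal sketch of an opt_get-style nested lookup (a hypothetical stand-in, not necessarily the repo's exact helper) makes the effect of that third argument concrete:

def opt_get(opt, keys, default=None):
    # Walk a list of nested keys through a dict-of-dicts config,
    # returning `default` as soon as any key is missing.
    ret = opt
    for k in keys:
        ret = ret.get(k) if isinstance(ret, dict) else None
        if ret is None:
            return default
    return ret

dataset_opt = {'mode': 'grand_conjoined_voice'}  # 'needs_collate' not set
print(opt_get(dataset_opt, ['needs_collate'], True))   # True  -> collate = C()
print(opt_get(dataset_opt, ['needs_collate'], False))  # False -> default collation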
@@ -288,7 +288,7 @@ class GptAsrHf2(nn.Module):
             mel_len = 0
         else:
             mel_emb = self.mel_encoder(mel_inputs)
-            assert mel_emb.shape[1] <= self.max_mel_frames
+            assert mel_emb.shape[1] <= self.max_mel_frames, f'{mel_emb.shape[1]} > {self.max_mel_frames}'
             mel_emb = mel_emb.permute(0,2,1).contiguous()
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
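The assert now fails with both sides of the comparison instead of a bare AssertionError, which is the point of a debugging commit like this one. Worth noting: the expression after the comma is only evaluated when the condition is false, so the f-string costs nothing on the happy path (and python -O strips asserts entirely). A small self-contained demo, with a hypothetical frame cap standing in for self.max_mel_frames:

import torch

max_mel_frames = 4                # hypothetical cap for the demo
mel_emb = torch.randn(1, 6, 512)  # (batch, frames, dim): one frame dimension too long

try:
    # The message is built lazily, only on failure.
    assert mel_emb.shape[1] <= max_mel_frames, f'{mel_emb.shape[1]} > {max_mel_frames}'
except AssertionError as e:
    print(e)  # "6 > 4" rather than an anonymous AssertionError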
@@ -303,8 +303,8 @@ class GptAsrHf2(nn.Module):
         return text_logits

     def forward(self, mel_inputs, text_inputs, return_attentions=False):
-        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
-        assert text_inputs.max() <= self.number_text_tokens
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
+        assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
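build_aligned_inputs_and_targets itself is outside this diff, but its name and call site suggest the usual GPT alignment trick: the input sequence gets a start token prepended and the target sequence gets a stop token appended, so that position i of the input predicts position i of the target. A hedged sketch of that presumed behavior:

import torch
import torch.nn.functional as F

def build_aligned_inputs_and_targets(text_inputs, start_token, stop_token):
    # Presumed behavior, inferred from the call site above; the real
    # method is not shown in this diff.
    inputs = F.pad(text_inputs, (1, 0), value=start_token)   # <start>, t0, t1, ...
    targets = F.pad(text_inputs, (0, 1), value=stop_token)   # t0, t1, ..., <stop>
    return inputs, targets

tokens = torch.tensor([[5, 6, 7]])
inp, tar = build_aligned_inputs_and_targets(tokens, start_token=1, stop_token=2)
print(inp)  # tensor([[1, 5, 6, 7]])
print(tar)  # tensor([[5, 6, 7, 2]])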
@@ -317,8 +317,8 @@ class GptAsrHf2(nn.Module):
         return loss_text.mean(), text_logits

     def text_only(self, text_inputs):
-        assert text_inputs.shape[1] <= self.max_symbols_per_phrase
-        assert text_inputs.max() <= self.number_text_tokens
+        assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
+        assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
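text_only mirrors forward minus the mel branch, and both hunks truncate at the same continuation: text_emb is the GPT's token embedding plus, presumably, a learned positional embedding on the elided next line, matching the mel path earlier in the file. (One small wart of the new assert messages: str() on a 0-dim tensor renders as e.g. 'tensor(257)', which is terse but still enough to spot an out-of-range token id.) A sketch of the shared embedding pattern, with hypothetical dimensions:

import torch
import torch.nn as nn

# Hypothetical sizes; the real values come from the model config.
number_text_tokens, model_dim, max_symbols_per_phrase = 256, 512, 100
token_embedding = nn.Embedding(number_text_tokens + 1, model_dim)         # stands in for self.gpt.get_input_embeddings()
text_pos_embedding = nn.Embedding(max_symbols_per_phrase + 2, model_dim)  # learned absolute positions

text_inputs = torch.randint(0, number_text_tokens, (2, 10))
positions = torch.arange(text_inputs.shape[1])
text_emb = token_embedding(text_inputs) + text_pos_embedding(positions)   # broadcasts over the batch
print(text_emb.shape)  # torch.Size([2, 10, 512])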