Better dimensional asserting

This commit is contained in:
James Betker 2021-12-25 23:18:25 -07:00
parent e959541494
commit 8acf3b3097

View File

@ -143,6 +143,10 @@ class UnifiedGptVoice(nn.Module):
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1]
assert self.max_symbols_per_phrase >= text_inputs.shape[1]
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1]
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@ -168,6 +172,8 @@ class UnifiedGptVoice(nn.Module):
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_symbols_per_phrase >= text_inputs.shape[1]
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@ -181,6 +187,8 @@ class UnifiedGptVoice(nn.Module):
"""
Performs autoregressive modeling on only speech data.
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1]
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)