Better dimensional asserting

This commit is contained in:
James Betker 2021-12-25 23:18:25 -07:00
parent e959541494
commit 8acf3b3097

View File

@@ -143,6 +143,10 @@ class UnifiedGptVoice(nn.Module):
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1]
assert self.max_symbols_per_phrase >= text_inputs.shape[1]
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1]
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@@ -168,6 +172,8 @@ class UnifiedGptVoice(nn.Module):
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_symbols_per_phrase >= text_inputs.shape[1]
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@@ -181,6 +187,8 @@ class UnifiedGptVoice(nn.Module):
"""
Performs autoregressive modeling on only speech data.
"""
assert self.max_symbols_per_phrase >= mel_inputs.shape[1]
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)