forked from mrq/DL-Art-School
Add code validation to autoregressive_codegen
This commit is contained in:
parent
99de63a922
commit
cdd12ff46c
|
@ -201,6 +201,8 @@ class AutoregressiveCodegen(nn.Module):
|
|||
|
||||
self.START_TOKEN=8192
|
||||
self.STOP_TOKEN=8193
|
||||
self.max_text_token_id = num_text_tokens
|
||||
self.max_mel_token_id = num_mel_tokens
|
||||
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
|
||||
self.encoder = CheckpointedXTransformerWrapper(
|
||||
num_tokens=num_text_tokens,
|
||||
|
@ -243,6 +245,9 @@ class AutoregressiveCodegen(nn.Module):
|
|||
}
|
||||
|
||||
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
|
||||
assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
|
||||
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
|
||||
|
||||
# Format mel_codes with a stop token on the end.
|
||||
mel_lengths = wav_lengths // 1024 + 1
|
||||
for b in range(mel_codes.shape[0]):
|
||||
|
|
Loading…
Reference in New Issue
Block a user