make code pred returns optional

This commit is contained in:
James Betker 2022-03-26 08:33:30 -06:00
parent 2a29a71c37
commit 6909f196b4

View File

@ -187,7 +187,7 @@ class DiffusionTtsFlat(nn.Module):
}
return groups
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
@ -256,7 +256,9 @@ class DiffusionTtsFlat(nn.Module):
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
return out, mel_pred
if return_code_pred:
return out, mel_pred
return out
@register_model