make code pred returns optional
This commit is contained in:
parent
2a29a71c37
commit
6909f196b4
|
@ -187,7 +187,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
}
|
}
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
|
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
@ -256,7 +256,9 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
extraneous_addition = extraneous_addition + p.mean()
|
extraneous_addition = extraneous_addition + p.mean()
|
||||||
out = out + extraneous_addition * 0
|
out = out + extraneous_addition * 0
|
||||||
|
|
||||||
|
if return_code_pred:
|
||||||
return out, mel_pred
|
return out, mel_pred
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user