From 6909f196b404fdc7d683f9631ebefbf439af8d73 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 26 Mar 2022 08:33:30 -0600 Subject: [PATCH] make code pred returns optional --- codes/models/audio/tts/unet_diffusion_tts_flat0.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py index 14f59f00..42574c72 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py @@ -187,7 +187,7 @@ class DiffusionTtsFlat(nn.Module): } return groups - def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False): + def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False): """ Apply the model to an input batch. @@ -256,7 +256,9 @@ class DiffusionTtsFlat(nn.Module): extraneous_addition = extraneous_addition + p.mean() out = out + extraneous_addition * 0 - return out, mel_pred + if return_code_pred: + return out, mel_pred + return out @register_model