diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen2.py b/codes/models/audio/music/unet_diffusion_waveform_gen2.py index cd7eb5d2..6614d728 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen2.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen2.py @@ -376,7 +376,7 @@ class DiffusionTts(nn.Module): aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning - def forward(self, x, timesteps, conditioning, conditioning_free=False): + def forward(self, x, timesteps, conditioning, return_surrogate=True, conditioning_free=False): """ Apply the model to an input batch. @@ -433,7 +433,10 @@ class DiffusionTts(nn.Module): extraneous_addition = extraneous_addition + p.mean() out = out + extraneous_addition * 0 - return out[:, :, :orig_x_shape], surrogate[:, :, :orig_x_shape] + if return_surrogate: + return out[:, :, :orig_x_shape], surrogate[:, :, :orig_x_shape] + else: + return out[:, :, :orig_x_shape] @register_model