From 8623c51902570b58fa707717dde53ea3a1045b81 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 1 Apr 2022 16:11:34 -0600
Subject: [PATCH] fix bug

---
 codes/models/audio/tts/unet_diffusion_tts_flat0.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
index 44f18b2c..d46c4549 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
@@ -200,7 +200,7 @@ class DiffusionTtsFlat(nn.Module):
         }
         return groups
 
-    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
@@ -228,7 +228,7 @@ class DiffusionTtsFlat(nn.Module):
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
         if not return_code_pred:
             return expanded_code_emb
 
@@ -263,7 +263,7 @@ class DiffusionTtsFlat(nn.Module):
         if precomputed_aligned_embeddings is not None:
             code_emb = precomputed_aligned_embeddings
         else:
-            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
         unused_params.append(self.unconditioned_embedding)
         if is_latent(aligned_conditioning):
 