diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index 699585e0..72e18b11 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -161,7 +161,14 @@ class DiffusionTtsFlat(nn.Module):
             code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
         else:
             unused_params.append(self.unconditioned_embedding)
-            cond_emb = self.contextual_embedder(conditioning_input)
+
+            speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+            conds = []
+            for j in range(speech_conditioning_input.shape[1]):
+                conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+            conds = torch.stack(conds, dim=1)
+            cond_emb = conds.mean(dim=1)
+
             if len(cond_emb.shape) == 3:  # Just take the first element.
                 cond_emb = cond_emb[:, :, 0]
             if is_latent(aligned_conditioning):
@@ -224,7 +231,7 @@ if __name__ == '__main__':
     clip = torch.randn(2, 100, 400)
     aligned_latent = torch.randn(2,388,512)
     aligned_sequence = torch.randint(0,8192,(2,388))
-    cond = torch.randn(2, 100, 400)
+    cond = torch.randn(2, 2, 100, 400)
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3)
     # Test with latent aligned conditioning
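
For context, a minimal, self-contained sketch of the pattern this diff introduces: embed each conditioning clip independently, then average the per-clip embeddings, so callers can pass either a single clip (3D tensor) or a stack of clips (4D tensor). DummyEmbedder below is a hypothetical stand-in for the model's contextual_embedder; only the tensor shapes and the loop/stack/mean logic follow the diff.

import torch
import torch.nn as nn

class DummyEmbedder(nn.Module):
    """Hypothetical stand-in for contextual_embedder:
    maps (batch, mel_bins, time) -> (batch, dim)."""
    def __init__(self, mel_bins=100, dim=512):
        super().__init__()
        self.proj = nn.Linear(mel_bins, dim)

    def forward(self, x):
        # Pool over the time axis, then project mel bins to the embedding dim.
        return self.proj(x.mean(dim=-1))

embedder = DummyEmbedder()

# Mirrors the updated test input: two conditioning clips per example.
cond = torch.randn(2, 2, 100, 400)  # (batch, n_clips, mel_bins, time)

# A 3D input (a single clip) is promoted to 4D with a singleton clip axis,
# so one code path handles both the old and new calling conventions.
if cond.dim() == 3:
    cond = cond.unsqueeze(1)

# Embed each clip independently, then average across clips.
conds = torch.stack([embedder(cond[:, j]) for j in range(cond.shape[1])], dim=1)
cond_emb = conds.mean(dim=1)
print(cond_emb.shape)  # torch.Size([2, 512])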