From a6181a489b482b8b9a35835f808d3ee38d791818 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 26 Mar 2022 22:49:14 -0600
Subject: [PATCH] Fix loss gapping caused by poor gradients into mel_pred

---
 codes/models/audio/tts/unet_diffusion_tts_flat0.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
index 42574c72..05f8847e 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
@@ -223,6 +223,7 @@ class DiffusionTtsFlat(nn.Module):
             code_emb = self.code_converter(code_emb)
             unused_params.extend(list(self.latent_converter.parameters()))
         code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
+        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
@@ -231,6 +232,8 @@ class DiffusionTtsFlat(nn.Module):
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
         mel_pred = self.mel_head(expanded_code_emb)
+        # Multiply mel_pred by the logical not of unconditioned_batches, which drops the gradient on unconditioned branches. We don't want that gradient training parameters through the codes_embedder, as it unbalances contributions to that network from the MSE loss.
+        mel_pred = mel_pred * unconditioned_batches.logical_not()
 
         # Everything after this comment is timestep dependent.
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -272,7 +275,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8192,(2,388))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionTtsFlat(512, layer_drop=.3)
+    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
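
Below is a minimal standalone sketch (not part of the patch) of the gradient-dropping trick the second hunk applies: multiplying a prediction by the logical not of a boolean batch mask zeroes both the prediction and its gradient for masked batch elements, so an auxiliary MSE loss cannot push gradients back through the conditioning embedder for unconditioned items. All names here (emb, head, mask) are illustrative stand-ins, not identifiers from the repository.

# Minimal sketch of the gradient-dropping trick, under illustrative names.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb = torch.randn(4, 8, requires_grad=True)   # stands in for the code embedding output
head = torch.nn.Linear(8, 8)                  # stands in for mel_head

# Boolean mask marking which batch elements are "unconditioned", analogous
# to the classifier-free-guidance-style masking done during training.
mask = torch.rand(4, 1) < 0.5

pred = head(emb)
# Multiplying by mask.logical_not() zeroes the prediction for masked rows;
# since the multiplier there is 0, no gradient flows back into `emb` for
# those batch elements.
pred = pred * mask.logical_not()

loss = F.mse_loss(pred, torch.zeros_like(pred))
loss.backward()

# Rows of emb.grad for masked (unconditioned) elements are exactly zero.
print(emb.grad[mask.squeeze(-1)])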