Fix loss gapping caused by poor gradients into mel_pred

James Betker 2022-03-26 22:49:14 -06:00
parent 0070867d0f
commit a6181a489b


@@ -223,6 +223,7 @@ class DiffusionTtsFlat(nn.Module):
code_emb = self.code_converter(code_emb)
unused_params.extend(list(self.latent_converter.parameters()))
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
+unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
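For orientation, here is a minimal standalone sketch of the conditioning-dropout scheme these hunks touch. Tensor shapes and variable names are illustrative assumptions, not the module's actual attributes:

import torch

batch, channels, seq = 2, 512, 100
training = True
unconditioned_percentage = 0.1
code_emb = torch.randn(batch, channels, seq)           # assumed conditioning embedding (B, C, T)
unconditioned_embedding = torch.randn(1, channels, 1)  # assumed learned "no conditioning" vector

# The commit adds an always-present default mask so later code can use it outside training too.
unconditioned_batches = torch.zeros((batch, 1, 1), device=code_emb.device).bool()
if training and unconditioned_percentage > 0:
    # Randomly drop the conditioning for whole batch elements, classifier-free-guidance style.
    unconditioned_batches = torch.rand((batch, 1, 1), device=code_emb.device) < unconditioned_percentage
    code_emb = torch.where(unconditioned_batches,
                           unconditioned_embedding.repeat(batch, 1, 1),
                           code_emb)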
@@ -231,6 +232,8 @@ class DiffusionTtsFlat(nn.Module):
code_emb)
expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
mel_pred = self.mel_head(expanded_code_emb)
+# Multiply mel_pred by (not unconditioned_batches), which drops the gradient on unconditioned batch elements. We don't want that gradient training parameters through the code embedder, as it unbalances contributions to that network from the MSE loss.
+mel_pred = mel_pred * unconditioned_batches.logical_not()
# Everything after this comment is timestep dependent.
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
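The fix itself multiplies mel_pred by the inverse of that mask, so unconditioned batch elements contribute zero gradient through the mel head and therefore cannot pull on the code embedding. A small standalone sketch of that effect, with hypothetical shapes:

import torch
import torch.nn.functional as F

mel_pred = torch.randn(2, 100, 400, requires_grad=True)
target = torch.randn(2, 100, 400)
# Pretend the first batch element ran with its conditioning dropped.
unconditioned_batches = torch.tensor([[[True]], [[False]]])

# Zeroing the prediction for unconditioned elements also zeroes its gradient.
masked_pred = mel_pred * unconditioned_batches.logical_not()
F.mse_loss(masked_pred, target).backward()

print(mel_pred.grad[0].abs().max())  # tensor(0.) -- no gradient for the unconditioned element
print(mel_pred.grad[1].abs().max())  # non-zero for the conditioned element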
@@ -272,7 +275,7 @@ if __name__ == '__main__':
aligned_sequence = torch.randint(0,8192,(2,388))
cond = torch.randn(2, 100, 400)
ts = torch.LongTensor([600, 600])
-model = DiffusionTtsFlat(512, layer_drop=.3)
+model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
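A hedged usage sketch of the updated test setup, assuming DiffusionTtsFlat from this file is in scope; clip's shape and the training-mode call follow the commented latent test above and are not verified against the full file:

import torch

model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
model.train()  # the random unconditioned masking only runs in training mode
clip = torch.randn(2, 100, 400)              # hypothetical diffusion input
aligned_sequence = torch.randint(0, 8192, (2, 388))
cond = torch.randn(2, 100, 400)
ts = torch.LongTensor([600, 600])
o = model(clip, ts, aligned_sequence, cond)  # same call pattern as the latent-conditioning test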