forked from mrq/DL-Art-School
Fix loss gapping caused by poor gradients into mel_pred
parent 0070867d0f
commit a6181a489b
@@ -223,6 +223,7 @@ class DiffusionTtsFlat(nn.Module):
             code_emb = self.code_converter(code_emb)
             unused_params.extend(list(self.latent_converter.parameters()))
         code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
+        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
@@ -231,6 +232,8 @@ class DiffusionTtsFlat(nn.Module):
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
         mel_pred = self.mel_head(expanded_code_emb)
+        # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
+        mel_pred = mel_pred * unconditioned_batches.logical_not()
 
         # Everything after this comment is timestep dependent.
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -272,7 +275,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8192,(2,388))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionTtsFlat(512, layer_drop=.3)
+    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
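In short, the change does two things: it pre-initializes unconditioned_batches to zeros so the mask always exists (including at inference, where no conditioning is dropped), and it multiplies the auxiliary mel_pred by the logical NOT of that mask so the MSE loss on mel_pred cannot push gradients through the code embedder for batch elements whose conditioning was randomly dropped. Below is a minimal, self-contained PyTorch sketch of that gating trick; the shapes, the target tensor, and the stand-ins for self.mel_head and the code-embedder output are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F

# Minimal sketch of the gradient-gating trick above (illustrative shapes and stand-ins).
torch.manual_seed(0)
code_emb = torch.randn(4, 80, 100, requires_grad=True)  # stand-in for the code embedder output
target = torch.randn(4, 80, 100)                        # stand-in for the target mel

# Pretend the first batch element had its conditioning dropped (CFG-style).
unconditioned_batches = torch.tensor([True, False, False, False]).view(4, 1, 1)

mel_pred = code_emb * 2.0                               # stand-in for self.mel_head(expanded_code_emb)
# Multiplying by the inverted mask zeroes the prediction for unconditioned elements,
# which also zeroes the gradient the MSE loss would otherwise send back through code_emb for them.
mel_pred = mel_pred * unconditioned_batches.logical_not()

F.mse_loss(mel_pred, target).backward()
print(code_emb.grad[0].abs().sum())  # tensor(0.) -- no gradient reaches the masked element
print(code_emb.grad[1].abs().sum())  # nonzero   -- conditioned elements still train normally

The zeros-initialized unconditioned_batches added in the first hunk makes this multiplication a no-op whenever nothing was dropped, for example at inference or when unconditioned_percentage is 0.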