From de54be5570c06223d9ce90576034e59559ed8126 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 1 Jun 2022 14:18:50 -0600
Subject: [PATCH] propagate diversity loss

---
 codes/models/audio/music/music_quantizer.py        | 4 ++--
 codes/models/audio/music/transformer_diffusion7.py | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py
index d55933c3..3f7c59af 100644
--- a/codes/models/audio/music/music_quantizer.py
+++ b/codes/models/audio/music/music_quantizer.py
@@ -204,16 +204,16 @@ class MusicQuantizer(nn.Module):
         h = self.encoder(h)
         h = self.enc_norm(h.permute(0,2,1))
         codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
+        diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         self.log_codes(codes)
         h = self.decoder(codevectors.permute(0,2,1))
         if return_decoder_latent:
-            return h
+            return h, diversity
 
         reconstructed = self.up(h)
         reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
 
         mse = F.mse_loss(reconstructed, orig_mel)
-        diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         return mse, diversity
 
     def log_codes(self, codes):
diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index 469c15cb..bc29014b 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -205,7 +205,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
-        from models.audio.mel2vec import ContrastiveTrainingWrapper
         self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
         self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
         del self.m2v.up
@@ -221,7 +220,8 @@
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
-            proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)
 
             # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
             if not quant_grad_enabled:
@@ -229,9 +229,10 @@
                 for p in self.m2v.parameters():
                     unused = unused + p.mean() * 0
                 proj = proj + unused
+                diversity_loss = diversity_loss * 0
 
         return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
-                         conditioning_free=conditioning_free)
+                         conditioning_free=conditioning_free), diversity_loss
 
     def get_debug_values(self, step, __):
         if self.m2v.total_codes > 0:
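
Note: the diversity term this patch now propagates is the standard codebook-usage
penalty used with Gumbel-softmax quantizers (as in wav2vec 2.0): it is zero when
the soft code assignments spread uniformly across the codebook and approaches one
as usage collapses onto a few codes. Below is a minimal sketch of how such a term
is typically derived from the quantizer's soft assignment probabilities; the
function name, the epsilon, and the assumed shape of `probs` are illustrative,
not lifted from MusicQuantizer.

import torch

def diversity_loss_from_probs(probs: torch.Tensor, num_codevectors: int) -> torch.Tensor:
    # probs: (num_frames, num_codevectors) soft assignments from the Gumbel-softmax.
    avg_probs = probs.mean(dim=0)  # mean usage of each codevector across frames
    # Perplexity = exp(entropy of the mean usage): equals num_codevectors for
    # perfectly uniform usage and 1 for total codebook collapse.
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))
    # The quantity the patch computes right after quantization and now returns
    # alongside the decoder latent.
    return (num_codevectors - perplexity) / num_codevectors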
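The `unused = unused + p.mean() * 0` block in the second hunk exists because DDP
raises an error when registered parameters receive no gradient during a step, and
the quantizer's parameters are unused while it is frozen. A standalone sketch of
that workaround follows, with a hypothetical name (`touch_params`); zeroing
`diversity_loss` in the same branch serves the same end, keeping the returned
loss a graph-connected no-op until `freeze_quantizer_until` is passed.

import torch
import torch.nn as nn

def touch_params(module: nn.Module, out: torch.Tensor) -> torch.Tensor:
    # Add a zero-valued function of every parameter to `out`, so autograd (and
    # therefore DDP's gradient reducer) sees a path to each parameter without
    # changing the value of `out`.
    unused = 0
    for p in module.parameters():
        unused = unused + p.mean() * 0
    return out + unused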