propagate diversity loss

This commit is contained in:
James Betker 2022-06-01 14:18:50 -06:00
parent 4c6ef42b38
commit de54be5570
2 changed files with 5 additions and 5 deletions

View File

@@ -204,16 +204,16 @@ class MusicQuantizer(nn.Module):
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
self.log_codes(codes)
h = self.decoder(codevectors.permute(0,2,1))
if return_decoder_latent:
return h
return h, diversity
reconstructed = self.up(h)
reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
mse = F.mse_loss(reconstructed, orig_mel)
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
return mse, diversity
def log_codes(self, codes):

View File

@@ -205,7 +205,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
self.internal_step = 0
self.freeze_quantizer_until = freeze_quantizer_until
self.diff = TransformerDiffusion(**kwargs)
from models.audio.mel2vec import ContrastiveTrainingWrapper
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
del self.m2v.up
@@ -221,7 +220,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):
proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
if not quant_grad_enabled:
@@ -229,9 +228,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
for p in self.m2v.parameters():
unused = unused + p.mean() * 0
proj = proj + unused
diversity_loss = diversity_loss * 0
return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
conditioning_free=conditioning_free)
conditioning_free=conditioning_free), diversity_loss
def get_debug_values(self, step, __):
if self.m2v.total_codes > 0: