propagate diversity loss
This commit is contained in:
parent
4c6ef42b38
commit
de54be5570
|
@ -204,16 +204,16 @@ class MusicQuantizer(nn.Module):
|
|||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
|
||||
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
||||
self.log_codes(codes)
|
||||
h = self.decoder(codevectors.permute(0,2,1))
|
||||
if return_decoder_latent:
|
||||
return h
|
||||
return h, diversity
|
||||
|
||||
reconstructed = self.up(h)
|
||||
reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
|
||||
|
||||
mse = F.mse_loss(reconstructed, orig_mel)
|
||||
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
||||
return mse, diversity
|
||||
|
||||
def log_codes(self, codes):
|
||||
|
|
|
@ -205,7 +205,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
self.internal_step = 0
|
||||
self.freeze_quantizer_until = freeze_quantizer_until
|
||||
self.diff = TransformerDiffusion(**kwargs)
|
||||
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
||||
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
|
||||
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
|
||||
del self.m2v.up
|
||||
|
@ -221,7 +220,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
|
||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||
with torch.set_grad_enabled(quant_grad_enabled):
|
||||
proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
|
||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
|
||||
|
||||
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
|
||||
if not quant_grad_enabled:
|
||||
|
@ -229,9 +228,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
for p in self.m2v.parameters():
|
||||
unused = unused + p.mean() * 0
|
||||
proj = proj + unused
|
||||
diversity_loss = diversity_loss * 0
|
||||
|
||||
return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
|
||||
conditioning_free=conditioning_free)
|
||||
conditioning_free=conditioning_free), diversity_loss
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.m2v.total_codes > 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user