forked from mrq/DL-Art-School
propagate diversity loss
parent 4c6ef42b38
commit de54be5570
@@ -204,16 +204,16 @@ class MusicQuantizer(nn.Module):
         h = self.encoder(h)
         h = self.enc_norm(h.permute(0,2,1))
         codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
+        diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         self.log_codes(codes)
         h = self.decoder(codevectors.permute(0,2,1))
         if return_decoder_latent:
-            return h
+            return h, diversity

         reconstructed = self.up(h)
         reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]

         mse = F.mse_loss(reconstructed, orig_mel)
-        diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         return mse, diversity

     def log_codes(self, codes):
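With this change the diversity term is computed immediately after quantization and returned alongside the decoder latent, so callers that pass return_decoder_latent=True receive it too, instead of it only being available from the (mse, diversity) reconstruction path. A minimal sketch of what that term measures, assuming perplexity is the exponentiated entropy of the mean code distribution as in wav2vec2 (the function and argument names below are illustrative, not from this repo):

import torch

def diversity_penalty(code_probs: torch.Tensor, num_codevectors: int) -> torch.Tensor:
    # code_probs: (N, num_codevectors) soft code assignments averaged over a batch.
    avg_probs = code_probs.mean(dim=0)
    # Perplexity equals num_codevectors when every codevector is used uniformly.
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))
    # Same form as the diff: near 0 for uniform codebook usage, approaching 1 as the codebook collapses.
    return (num_codevectors - perplexity) / num_codevectors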
@@ -205,7 +205,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
-        from models.audio.mel2vec import ContrastiveTrainingWrapper
         self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
         self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
         del self.m2v.up
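Besides dropping the unused ContrastiveTrainingWrapper import, the constructor keeps pinning the quantizer's Gumbel temperature to its floor rather than annealing it. A minimal sketch of the schedule that pinning short-circuits, assuming a wav2vec2-style exponential decay (the max/min temperatures and decay constant below are assumptions, not shown in this diff):

def gumbel_temperature(step: int, max_temp: float = 2.0,
                       min_temp: float = 0.5, decay: float = 0.999995) -> float:
    # Exponential decay from max_temp toward min_temp; lower values give harder code assignments.
    return max(max_temp * (decay ** step), min_temp)

Setting the temperature directly to the minimum makes the quantizer behave like a nearly hard codebook from the start, which matches it starting out frozen via freeze_quantizer_until.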
@@ -221,7 +220,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
-            proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)

         # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled:
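MusicQuantizer.forward now yields a (latent, diversity) tuple when return_decoder_latent=True, so the latent has to be separated from the diversity term before any tensor op; as written above, .permute(0,2,1) is applied to the call's return value, but it can only apply to the latent itself. A minimal sketch of the intended call pattern, with names taken from the diff (quantizer stands in for self.m2v):

# Unpack first, then permute the decoder latent into the layout the diffusion model expects.
latent, diversity_loss = quantizer(truth_mel, return_decoder_latent=True)
proj = latent.permute(0, 2, 1)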
@@ -229,9 +228,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             for p in self.m2v.parameters():
                 unused = unused + p.mean() * 0
             proj = proj + unused
+            diversity_loss = diversity_loss * 0

         return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
-                         conditioning_free=conditioning_free)
+                         conditioning_free=conditioning_free), diversity_loss

     def get_debug_values(self, step, __):
         if self.m2v.total_codes > 0:
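While the quantizer is still frozen, the hunk above touches every m2v parameter through a zero-valued sum and also zeroes diversity_loss, so DDP sees all parameters in the graph while the propagated loss contributes no quantizer gradient. The forward call now returns (diffusion_output, diversity_loss); a minimal sketch of how a training step might combine the two, assuming an MSE diffusion objective and a weighting that would live in the training configuration rather than in this commit (diversity_weight and the target tensor are illustrative):

import torch
import torch.nn.functional as F

def combined_loss(diffusion_out: torch.Tensor, diffusion_target: torch.Tensor,
                  diversity_loss: torch.Tensor, diversity_weight: float = 0.1) -> torch.Tensor:
    # Weighted sum of the diffusion objective and the propagated codebook-diversity penalty.
    return F.mse_loss(diffusion_out, diffusion_target) + diversity_weight * diversity_loss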