diff --git a/codes/models/audio/music/instrument_quantizer.py b/codes/models/audio/music/instrument_quantizer.py index ecbfd44d..00ea462d 100644 --- a/codes/models/audio/music/instrument_quantizer.py +++ b/codes/models/audio/music/instrument_quantizer.py @@ -28,7 +28,6 @@ class SelfClassifyingHead(nn.Module): h = self.dec(x) o = self.to_output(h[:, -1]) q, c, _ = self.quantizer(o, used_codes) - q = torch.sigmoid(q) return q, c def forward(self, x, target): @@ -37,10 +36,13 @@ class SelfClassifyingHead(nn.Module): outputs = [] results = [] codes = [] + q_reg = 0 for i in range(self.seq_len): q, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1)) + q_reg = q_reg + (q ** 2).mean() + s = torch.sigmoid(q) - outputs.append(q) + outputs.append(s) output = torch.stack(outputs, dim=1).sum(1) # If the addition would strictly make the result worse, set it to 0. Sometimes. @@ -49,13 +51,13 @@ class SelfClassifyingHead(nn.Module): probabilistic_worsen = torch.rand_like(worsen) * worsen > .5 output = output * probabilistic_worsen.unsqueeze(-1) # This is non-differentiable, but still deterministic. c[probabilistic_worsen] = -1 # Code of -1 means the code was unused. - q = q * probabilistic_worsen.unsqueeze(-1) - outputs[-1] = q + s = s * probabilistic_worsen.unsqueeze(-1) + outputs[-1] = s codes.append(c) - stack.append(self.to_decoder(q)) + stack.append(self.to_decoder(s)) results.append(output) - return results, torch.cat(codes, dim=0) + return results, torch.cat(codes, dim=0), q_reg / self.seq_len class VectorResBlock(nn.Module): @@ -112,13 +114,13 @@ class InstrumentQuantizer(nn.Module): for lyr in self.encoder: h = lyr(h) - reconstructions, codes = self.heads(h, f) + reconstructions, codes, q_reg = self.heads(h, f) reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions]) r_follow = torch.arange(1, reconstruction_losses.shape[0]+1, device=x.device) reconstruction_losses = (reconstruction_losses * r_follow / r_follow.shape[0]) self.log_codes(codes) - return reconstruction_losses + return reconstruction_losses, q_reg def log_codes(self, codes): if self.internal_step % 5 == 0: