diff --git a/codes/models/audio/music/instrument_quantizer.py b/codes/models/audio/music/instrument_quantizer.py
index ff95ef84..640b663b 100644
--- a/codes/models/audio/music/instrument_quantizer.py
+++ b/codes/models/audio/music/instrument_quantizer.py
@@ -19,14 +19,22 @@ class SelfClassifyingHead(nn.Module):
         self.temperature = init_temperature
         self.dec = Decoder(dim=dim, depth=head_depth, heads=4, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout,
                            use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
-        self.quantizer = VectorQuantize(dim, classes, codebook_dim=32, use_cosine_sim=True, threshold_ema_dead_code=2,
+        self.quantizer = VectorQuantize(out_dim, classes, codebook_dim=16, use_cosine_sim=False, threshold_ema_dead_code=2,
                                         sample_codebook_temp=init_temperature)
         self.to_output = nn.Linear(dim, out_dim)
+        self.to_decoder = nn.Linear(out_dim, dim)
+        self.scale = nn.Linear(dim, 1, bias=False)
+        self.scale.weight.data.zero_()
 
     def do_ar_step(self, x, used_codes):
+        MIN = -12
+
         h = self.dec(x)
-        h, c, _ = self.quantizer(h[:, -1], used_codes)
-        return h, c
+        o = self.to_output(h[:, -1])
+        scale = (self.scale(h[:, -1]) + 1)
+        q, c, _ = self.quantizer(o, used_codes)
+        q = F.relu(q * scale) + MIN
+        return q, c
 
     def forward(self, x):
         with torch.no_grad():
@@ -39,12 +47,12 @@ class SelfClassifyingHead(nn.Module):
         results = []
         codes = []
         for i in range(self.seq_len):
-            h, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
+            q, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
             c_mask = c
             c_mask[c==0] = -1  # Mask this out because we want code=0 to be capable of being repeated.
-            codes.append(c)
-            stack.append(h.detach())  # Detach here to avoid piling up gradients from autoregression. We really just want the gradients to flow to the selected class embeddings and the selector for those classes.
-            outputs.append(self.to_output(h))
+            codes.append(c_mask)
+            stack.append(self.to_decoder(q))
+            outputs.append(q)
             results.append(torch.stack(outputs, dim=1).sum(1))
         return results, torch.cat(codes, dim=0)
@@ -81,7 +89,6 @@ class InstrumentQuantizer(nn.Module):
         self.op_dim = op_dim
         self.proj = nn.Linear(op_dim, dim)
         self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
-        self.final_bn = nn.BatchNorm1d(dim)
         self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
         self.min_gumbel_temperature = min_temp
         self.max_gumbel_temperature = max_temp
@@ -99,7 +106,6 @@ class InstrumentQuantizer(nn.Module):
         h = self.proj(f)
         for lyr in self.encoder:
             h = lyr(h)
-        h = self.final_bn(h.unsqueeze(-1)).squeeze(-1)
         reconstructions, codes = self.heads(h)
         reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
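
To make the new `do_ar_step` output path easier to follow, here is a minimal standalone sketch of just the transform. The quantizer is stubbed out, and `dim`, `out_dim`, and the batch size are hypothetical, not values from the diff. Because `self.scale` is zero-initialized, `(self.scale(h) + 1)` is exactly 1 at the start of training, so the gate begins as a pass-through; the `F.relu(...) + MIN` clamp then floors every reconstructed value at -12.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, out_dim = 512, 256            # hypothetical widths; the diff does not fix these
h_last = torch.randn(8, dim)       # stands in for the decoder output h[:, -1]

to_output = nn.Linear(dim, out_dim)
scale = nn.Linear(dim, 1, bias=False)
scale.weight.data.zero_()          # zero init => (scale(h) + 1) == 1 on the first step

MIN = -12
o = to_output(h_last)              # projection into the quantizer's input space
s = scale(h_last) + 1              # learned per-item gain, exactly 1 at init
q = o                              # stand-in for the quantized output of `o`
out = F.relu(q * s) + MIN          # non-negative scaled output, floored at MIN
assert bool((out >= MIN).all())    # the floor holds by construction
```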
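
The `c_mask` trick in `forward` deserves a note: previously emitted codes are fed back to `do_ar_step` as `used_codes`, and rewriting code 0 to -1 (an index no codebook entry occupies) removes it from that list, which is what the inline comment means by letting code 0 be repeated. A tiny illustration with made-up code values (using `clone()` for clarity, where the diff mutates `c` in place through the alias `c_mask = c`):

```python
import torch

c = torch.tensor([0, 3, 0, 7])   # hypothetical codes from one AR step
c_mask = c.clone()               # the diff aliases (c_mask = c) and mutates in place
c_mask[c == 0] = -1              # -1 matches no codebook index, so 0 is never "used"
print(c_mask)                    # tensor([-1,  3, -1,  7])
```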
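
Finally, note how `results` interacts with the loss in `InstrumentQuantizer.forward`: each entry is the cumulative sum of the quantized outputs up to that step, and every partial sum is scored against the same target, so each additional code is effectively trained to refine the residual left by the previous ones. A shape-level sketch, where the target and per-step outputs are random stand-ins:

```python
import torch
import torch.nn.functional as F

target = torch.randn(8, 256)                          # stand-in for the input features
outputs = [torch.randn(8, 256) for _ in range(4)]     # stand-ins for q at each AR step
results = [torch.stack(outputs[:i + 1], dim=1).sum(1) for i in range(4)]
losses = torch.stack([F.mse_loss(r, target) for r in results])
print(losses.shape)                                   # torch.Size([4]), one loss per partial sum
```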