diff --git a/codes/models/audio/music/transformer_diffusion5.py b/codes/models/audio/music/transformer_diffusion5.py
index 3198de1a..fc938086 100644
--- a/codes/models/audio/music/transformer_diffusion5.py
+++ b/codes/models/audio/music/transformer_diffusion5.py
@@ -214,6 +214,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         self.total_codes = 0
 
         del self.m2v.m2v.encoder
+        del self.m2v.reconstruction_net
+        del self.m2v.m2v.projector.projection
+        del self.m2v.project_hid
+        del self.m2v.project_q
 
     def update_for_step(self, step, *args):
         self.internal_step = step
@@ -224,7 +228,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
 
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
-        _, proj = self.m2v.m2v.projector(proj)
+        proj = self.m2v.m2v.projector.layer_norm(proj)
         vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
         self.log_codes(probs)
         return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)