diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index ac05f1dd..3b213974 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module): def forward(self, x): h = checkpoint(self.init, x) - h = self.attn(h + h = self.attn(h) return h.mean(dim=2)