diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py index 1de4a285..0ac5d7af 100644 --- a/codes/models/audio/music/transformer_diffusion9.py +++ b/codes/models/audio/music/transformer_diffusion9.py @@ -57,7 +57,7 @@ class DietAttentionBlock(TimestepBlock): h = torch.cat([ah, h], dim=-1) h = F.gelu(self.attnorm(h)) h = checkpoint(self.ff, h) - return h * self.exit_mult + return h * self.exit_mult + x class TransformerDiffusion(nn.Module):