diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py index b2068d28..40d24665 100644 --- a/codes/models/audio/music/transformer_diffusion9.py +++ b/codes/models/audio/music/transformer_diffusion9.py @@ -43,8 +43,7 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): class DietAttentionBlock(TimestepBlock): def __init__(self, in_dim, dim, heads, dropout): super().__init__() - self.proj = nn.Linear(in_dim, dim) - self.proj.bias.data.zero_() + self.proj = nn.Linear(in_dim, dim, bias=False) self.rms_scale_norm = RMSScaleShiftNorm(dim, bias=False) self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout) self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)