one more adjustment

This commit is contained in:
James Betker 2022-06-11 08:01:46 -06:00
parent df0cdf1a4f
commit 41170f97e9

View File

@ -43,8 +43,7 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
class DietAttentionBlock(TimestepBlock):
def __init__(self, in_dim, dim, heads, dropout):
super().__init__()
self.proj = nn.Linear(in_dim, dim)
self.proj.bias.data.zero_()
self.proj = nn.Linear(in_dim, dim, bias=False)
self.rms_scale_norm = RMSScaleShiftNorm(dim, bias=False)
self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout)
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)