This commit is contained in:
sunyt32 2023-09-28 17:05:53 +00:00
parent fd8234c2ac
commit 59fc5f7d3d

View File

@ -63,7 +63,7 @@ class MultiScaleRetention(nn.Module):
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.norm_eps, elementwise_affine=False))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):