diff --git a/torchscale/component/multiscale_retention.py b/torchscale/component/multiscale_retention.py index 02d3cd7..bf3a23b 100644 --- a/torchscale/component/multiscale_retention.py +++ b/torchscale/component/multiscale_retention.py @@ -67,7 +67,7 @@ class MultiScaleRetention(nn.Module): self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True)) - self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False)) + self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=1e-6, elementwise_affine=False)) self.reset_parameters() def reset_parameters(self):