diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index 5898f39..b267dd8 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -258,6 +258,9 @@ class RetNetConfig(object): self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + # RetNet's RelPos base + self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000) + if self.deepnorm: self.decoder_normalize_before = False self.subln = False diff --git a/torchscale/architecture/retnet.py b/torchscale/architecture/retnet.py index b29928c..83a368b 100644 --- a/torchscale/architecture/retnet.py +++ b/torchscale/architecture/retnet.py @@ -22,7 +22,7 @@ from torchscale.component.rms_norm import RMSNorm class RetNetRelPos(nn.Module): def __init__(self, args): super().__init__() - angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2)) + angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2)) angle = angle.unsqueeze(-1).repeat(1, 2).flatten() decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float))) self.register_buffer("angle", angle)