added arg to change RelPos's base

This commit is contained in:
mrq 2023-09-19 19:18:49 -05:00
parent 881d03079d
commit ce77afe916
2 changed files with 4 additions and 1 deletions

View File

@ -258,6 +258,9 @@ class RetNetConfig(object):
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
# RetNet's RelPos base
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
if self.deepnorm: if self.deepnorm:
self.decoder_normalize_before = False self.decoder_normalize_before = False
self.subln = False self.subln = False

View File

@ -22,7 +22,7 @@ from torchscale.component.rms_norm import RMSNorm
class RetNetRelPos(nn.Module): class RetNetRelPos(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2)) angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
angle = angle.unsqueeze(-1).repeat(1, 2).flatten() angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float))) decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
self.register_buffer("angle", angle) self.register_buffer("angle", angle)