added arg to change RelPos's base
This commit is contained in:
parent
881d03079d
commit
ce77afe916
|
@ -258,6 +258,9 @@ class RetNetConfig(object):
|
||||||
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||||
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||||
|
|
||||||
|
# RetNet's RelPos base
|
||||||
|
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.decoder_normalize_before = False
|
self.decoder_normalize_before = False
|
||||||
self.subln = False
|
self.subln = False
|
||||||
|
|
|
@ -22,7 +22,7 @@ from torchscale.component.rms_norm import RMSNorm
|
||||||
class RetNetRelPos(nn.Module):
|
class RetNetRelPos(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
|
angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
|
||||||
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
|
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
|
||||||
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
|
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
|
||||||
self.register_buffer("angle", angle)
|
self.register_buffer("angle", angle)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user