added arg to change RelPos's base
This commit is contained in:
parent
881d03079d
commit
ce77afe916
|
@ -258,6 +258,9 @@ class RetNetConfig(object):
|
|||
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||
|
||||
# RetNet's RelPos base
|
||||
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
|
||||
|
||||
if self.deepnorm:
|
||||
self.decoder_normalize_before = False
|
||||
self.subln = False
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchscale.component.rms_norm import RMSNorm
|
|||
class RetNetRelPos(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
|
||||
angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
|
||||
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
|
||||
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
|
||||
self.register_buffer("angle", angle)
|
||||
|
|
Loading…
Reference in New Issue
Block a user