diff --git a/README.md b/README.md index 9d469cc..7deac03 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,9 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture * enabled by *multiway=True*. * It provides a pool of Transformer's parameters used for different modalities. +- [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554) + * enabled by *xpos_rel_pos=True*. + - [Relative position bias](https://arxiv.org/abs/1910.10683) * enabled by adjusting *rel_pos_buckets* and *max_rel_pos*. diff --git a/examples/fairseq/models/language_modeling.py b/examples/fairseq/models/language_modeling.py index 7ec7b33..71bf1a5 100644 --- a/examples/fairseq/models/language_modeling.py +++ b/examples/fairseq/models/language_modeling.py @@ -190,6 +190,12 @@ class LanguageConfig(FairseqDataclass): max_rel_pos: Optional[int] = field( default=0, ) + xpos_rel_pos: Optional[bool] = field( + default=False, + ) + xpos_scale_base: Optional[int] = field( + default=512, + ) @register_model("lm", dataclass=LanguageConfig) diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index 347ee23..6b8bb5c 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -49,6 +49,8 @@ class EncoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) if self.deepnorm: self.encoder_normalize_before = False @@ -110,6 +112,8 @@ class DecoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) if self.deepnorm: self.decoder_normalize_before = False @@ -178,6 +182,8 @@ class EncoderDecoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) if self.deepnorm: self.encoder_normalize_before = False diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 819c6d3..3c67c28 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -9,6 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm from torch import nn from .multiway_network import MultiwayWrapper +from .xpos_relative_position import XPOS class MultiheadAttention(nn.Module): @@ -44,6 +45,11 @@ class MultiheadAttention(nn.Module): else None ) self.dropout_module = torch.nn.Dropout(dropout, inplace=True) + self.xpos = ( + XPOS(self.head_dim, args.xpos_scale_base) + if args.xpos_rel_pos and self.self_attention + else None + ) def reset_parameters(self): nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) @@ -99,6 +105,14 @@ class MultiheadAttention(nn.Module): ) src_len = k.size(1) + if self.xpos is not None: + if incremental_state is not None: + offset = src_len - 1 + else: + offset = 0 + k = self.xpos(k, offset=0, downscale=True) + q = self.xpos(q, offset=offset, downscale=False) + attn_weights = torch.bmm(q, k.transpose(1, 2)) if attn_mask is not None: diff --git a/torchscale/component/xpos_relative_position.py b/torchscale/component/xpos_relative_position.py new file mode 100644 index 0000000..a3ec129 --- /dev/null +++ b/torchscale/component/xpos_relative_position.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import numpy as np +import torch +import torch.nn as nn + +def fixed_pos_embedding(x): + seq_len, dim = x.shape + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + +def apply_rotary_pos_emb(x, sin, cos, scale=1): + sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class XPOS(nn.Module): + def __init__( + self, head_dim, scale_base=512 + ): + super().__init__() + self.head_dim = head_dim + self.scale_base = scale_base + self.register_buffer( + "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) + ) + + def forward(self, x, offset=0, downscale=False): + length = x.shape[1] + min_pos = -(length + offset) // 2 + max_pos = length + offset + min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, sin, cos, scale) + return x