diff --git a/README.md b/README.md
index 9d469cc..7deac03 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,9 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
   * enabled by *multiway=True*.
   * It provides a pool of Transformer's parameters used for different modalities.
 
+- [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554)
+  * enabled by *xpos_rel_pos=True*.
+
 - [Relative position bias](https://arxiv.org/abs/1910.10683)
   * enabled by adjusting *rel_pos_buckets* and *max_rel_pos*.
 
diff --git a/examples/fairseq/models/language_modeling.py b/examples/fairseq/models/language_modeling.py
index 7ec7b33..71bf1a5 100644
--- a/examples/fairseq/models/language_modeling.py
+++ b/examples/fairseq/models/language_modeling.py
@@ -190,6 +190,12 @@ class LanguageConfig(FairseqDataclass):
     max_rel_pos: Optional[int] = field(
         default=0,
     )
+    xpos_rel_pos: Optional[bool] = field(
+        default=False,
+    )
+    xpos_scale_base: Optional[int] = field(
+        default=512,
+    )
 
 
 @register_model("lm", dataclass=LanguageConfig)
diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index a77a742..6b8bb5c 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -49,7 +49,8 @@ class EncoderConfig(object):
         self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
         self.fsdp = kwargs.pop("fsdp", False)
         self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos = kwargs.pop("xpos", False)
+        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
+        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.encoder_normalize_before = False
@@ -111,7 +112,8 @@ class DecoderConfig(object):
         self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
         self.fsdp = kwargs.pop("fsdp", False)
         self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos = kwargs.pop("xpos", False)
+        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
+        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.decoder_normalize_before = False
@@ -180,7 +182,8 @@ class EncoderDecoderConfig(object):
         self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
         self.fsdp = kwargs.pop("fsdp", False)
         self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos = kwargs.pop("xpos", False)
+        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
+        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.encoder_normalize_before = False
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index b00e54d..3c67c28 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -9,7 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm
 from torch import nn
 
 from .multiway_network import MultiwayWrapper
-from .xpos import XPOS
+from .xpos_relative_position import XPOS
 
 
 class MultiheadAttention(nn.Module):
@@ -46,8 +46,8 @@ class MultiheadAttention(nn.Module):
         )
         self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
         self.xpos = (
-            XPOS(self.head_dim)
-            if args.xpos and self.self_attention
+            XPOS(self.head_dim, args.xpos_scale_base)
+            if args.xpos_rel_pos and self.self_attention
             else None
         )
 
@@ -110,8 +110,8 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
-            k = self.xpos(k, downscale=True)
-            q = self.xpos(q, offset=offset)
+            k = self.xpos(k, offset=0, downscale=True)
+            q = self.xpos(q, offset=offset, downscale=False)
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
 
diff --git a/torchscale/component/xpos.py b/torchscale/component/xpos_relative_position.py
similarity index 97%
rename from torchscale/component/xpos.py
rename to torchscale/component/xpos_relative_position.py
index 3ab2380..a3ec129 100644
--- a/torchscale/component/xpos.py
+++ b/torchscale/component/xpos_relative_position.py
@@ -49,7 +49,7 @@ class XPOS(nn.Module):
     def forward(self, x, offset=0, downscale=False):
         length = x.shape[1]
         min_pos = -(length + offset) // 2
-        max_pos = length + offset - min_pos
+        max_pos = length + offset + min_pos
         scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
         sin, cos = fixed_pos_embedding(scale)
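
With this patch applied, the new flags are plain config kwargs. Below is a minimal usage sketch (not part of the patch) based on the `DecoderConfig`/`Decoder` entry points shown in the repository README; `vocab_size=64000` is just an illustrative value, and `xpos_scale_base=512` repeats the default added in `config.py` above.

```python
# Sketch only: enable XPOS relative positions through the config kwargs
# introduced in this diff. Assumes the torchscale package (and its apex
# dependency) is installed and the patch is applied.
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder

config = DecoderConfig(
    vocab_size=64000,        # illustrative vocabulary size
    xpos_rel_pos=True,       # use XPOS in self-attention (formerly `xpos`)
    xpos_scale_base=512,     # default scale base, same as the config default
)
decoder = Decoder(config)
```

The component can also be imported directly from its renamed module, e.g. `from torchscale.component.xpos_relative_position import XPOS`, matching the import change in `multihead_attention.py`.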