diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index 347ee23..a77a742 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -49,6 +49,7 @@ class EncoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos = kwargs.pop("xpos", False) if self.deepnorm: self.encoder_normalize_before = False @@ -110,6 +111,7 @@ class DecoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos = kwargs.pop("xpos", False) if self.deepnorm: self.decoder_normalize_before = False @@ -178,6 +180,7 @@ class EncoderDecoderConfig(object): self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) self.fsdp = kwargs.pop("fsdp", False) self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos = kwargs.pop("xpos", False) if self.deepnorm: self.encoder_normalize_before = False diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 819c6d3..b00e54d 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -9,6 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm from torch import nn from .multiway_network import MultiwayWrapper +from .xpos import XPOS class MultiheadAttention(nn.Module): @@ -44,6 +45,11 @@ class MultiheadAttention(nn.Module): else None ) self.dropout_module = torch.nn.Dropout(dropout, inplace=True) + self.xpos = ( + XPOS(self.head_dim) + if args.xpos and self.self_attention + else None + ) def reset_parameters(self): nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) @@ -99,6 +105,14 @@ class MultiheadAttention(nn.Module): ) src_len = k.size(1) + if self.xpos is not None: + if incremental_state is not None: + offset = src_len - 1 + else: + offset = 0 + k = self.xpos(k, downscale=True) + q = self.xpos(q, offset=offset) + attn_weights = torch.bmm(q, k.transpose(1, 2)) if attn_mask is not None: diff --git a/torchscale/component/xpos.py b/torchscale/component/xpos.py new file mode 100644 index 0000000..3ab2380 --- /dev/null +++ b/torchscale/component/xpos.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import numpy as np +import torch +import torch.nn as nn + +def fixed_pos_embedding(x): + seq_len, dim = x.shape + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + +def apply_rotary_pos_emb(x, sin, cos, scale=1): + sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class XPOS(nn.Module): + def __init__( + self, head_dim, scale_base=512 + ): + super().__init__() + self.head_dim = head_dim + self.scale_base = scale_base + self.register_buffer( + "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) + ) + + def forward(self, x, offset=0, downscale=False): + length = x.shape[1] + min_pos = -(length + offset) // 2 + max_pos = length + offset - min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, sin, cos, scale) + return x