Merge pull request #11 from microsoft/xpos
Adding the official implementation of Xpos (https://arxiv.org/abs/2212.10554)
This commit is contained in:
commit
776b070d68
|
@ -84,6 +84,9 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
|
|||
* enabled by *multiway=True*.
|
||||
* It provides a pool of Transformer's parameters used for different modalities.
|
||||
|
||||
- [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554)
|
||||
* enabled by *xpos_rel_pos=True*.
|
||||
|
||||
- [Relative position bias](https://arxiv.org/abs/1910.10683)
|
||||
* enabled by adjusting *rel_pos_buckets* and *max_rel_pos*.
|
||||
|
||||
|
|
|
@ -190,6 +190,12 @@ class LanguageConfig(FairseqDataclass):
|
|||
max_rel_pos: Optional[int] = field(
|
||||
default=0,
|
||||
)
|
||||
xpos_rel_pos: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
xpos_scale_base: Optional[int] = field(
|
||||
default=512,
|
||||
)
|
||||
|
||||
|
||||
@register_model("lm", dataclass=LanguageConfig)
|
||||
|
|
|
@ -49,6 +49,8 @@ class EncoderConfig(object):
|
|||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||
self.fsdp = kwargs.pop("fsdp", False)
|
||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||
|
||||
if self.deepnorm:
|
||||
self.encoder_normalize_before = False
|
||||
|
@ -110,6 +112,8 @@ class DecoderConfig(object):
|
|||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||
self.fsdp = kwargs.pop("fsdp", False)
|
||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||
|
||||
if self.deepnorm:
|
||||
self.decoder_normalize_before = False
|
||||
|
@ -178,6 +182,8 @@ class EncoderDecoderConfig(object):
|
|||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||
self.fsdp = kwargs.pop("fsdp", False)
|
||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||
|
||||
if self.deepnorm:
|
||||
self.encoder_normalize_before = False
|
||||
|
|
|
@ -9,6 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm
|
|||
from torch import nn
|
||||
|
||||
from .multiway_network import MultiwayWrapper
|
||||
from .xpos_relative_position import XPOS
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
|
@ -44,6 +45,11 @@ class MultiheadAttention(nn.Module):
|
|||
else None
|
||||
)
|
||||
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.xpos = (
|
||||
XPOS(self.head_dim, args.xpos_scale_base)
|
||||
if args.xpos_rel_pos and self.self_attention
|
||||
else None
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
|
@ -99,6 +105,14 @@ class MultiheadAttention(nn.Module):
|
|||
)
|
||||
src_len = k.size(1)
|
||||
|
||||
if self.xpos is not None:
|
||||
if incremental_state is not None:
|
||||
offset = src_len - 1
|
||||
else:
|
||||
offset = 0
|
||||
k = self.xpos(k, offset=0, downscale=True)
|
||||
q = self.xpos(q, offset=offset, downscale=False)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
|
||||
if attn_mask is not None:
|
||||
|
|
65
torchscale/component/xpos_relative_position.py
Normal file
65
torchscale/component/xpos_relative_position.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def fixed_pos_embedding(x):
|
||||
seq_len, dim = x.shape
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
|
||||
sinusoid_inp = (
|
||||
torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
|
||||
)
|
||||
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
||||
|
||||
def rotate_every_two(x):
|
||||
x1 = x[:, :, ::2]
|
||||
x2 = x[:, :, 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
|
||||
|
||||
def duplicate_interleave(m):
|
||||
"""
|
||||
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
|
||||
"""
|
||||
dim0 = m.shape[0]
|
||||
m = m.view(-1, 1) # flatten the matrix
|
||||
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
|
||||
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
|
||||
return m
|
||||
|
||||
def apply_rotary_pos_emb(x, sin, cos, scale=1):
|
||||
sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
|
||||
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
|
||||
return (x * cos) + (rotate_every_two(x) * sin)
|
||||
|
||||
|
||||
class XPOS(nn.Module):
|
||||
def __init__(
|
||||
self, head_dim, scale_base=512
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.scale_base = scale_base
|
||||
self.register_buffer(
|
||||
"scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)
|
||||
)
|
||||
|
||||
def forward(self, x, offset=0, downscale=False):
|
||||
length = x.shape[1]
|
||||
min_pos = -(length + offset) // 2
|
||||
max_pos = length + offset + min_pos
|
||||
scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
|
||||
sin, cos = fixed_pos_embedding(scale)
|
||||
|
||||
if scale.shape[0] > length:
|
||||
scale = scale[-length:]
|
||||
sin = sin[-length:]
|
||||
cos = cos[-length:]
|
||||
|
||||
if downscale:
|
||||
scale = 1 / scale
|
||||
|
||||
x = apply_rotary_pos_emb(x, sin, cos, scale)
|
||||
return x
|
Loading…
Reference in New Issue
Block a user