Add XPOS
This commit is contained in:
parent
aa36203042
commit
f9d98f4b68
|
@ -49,6 +49,7 @@ class EncoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
|
self.xpos = kwargs.pop("xpos", False)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.encoder_normalize_before = False
|
self.encoder_normalize_before = False
|
||||||
|
@ -110,6 +111,7 @@ class DecoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
|
self.xpos = kwargs.pop("xpos", False)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.decoder_normalize_before = False
|
self.decoder_normalize_before = False
|
||||||
|
@ -178,6 +180,7 @@ class EncoderDecoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
|
self.xpos = kwargs.pop("xpos", False)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.encoder_normalize_before = False
|
self.encoder_normalize_before = False
|
||||||
|
|
|
@ -9,6 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .multiway_network import MultiwayWrapper
|
from .multiway_network import MultiwayWrapper
|
||||||
|
from .xpos import XPOS
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(nn.Module):
|
class MultiheadAttention(nn.Module):
|
||||||
|
@ -44,6 +45,11 @@ class MultiheadAttention(nn.Module):
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
|
self.xpos = (
|
||||||
|
XPOS(self.head_dim)
|
||||||
|
if args.xpos and self.self_attention
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
@ -99,6 +105,14 @@ class MultiheadAttention(nn.Module):
|
||||||
)
|
)
|
||||||
src_len = k.size(1)
|
src_len = k.size(1)
|
||||||
|
|
||||||
|
if self.xpos is not None:
|
||||||
|
if incremental_state is not None:
|
||||||
|
offset = src_len - 1
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
|
k = self.xpos(k, downscale=True)
|
||||||
|
q = self.xpos(q, offset=offset)
|
||||||
|
|
||||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
|
|
65
torchscale/component/xpos.py
Normal file
65
torchscale/component/xpos.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def fixed_pos_embedding(x):
|
||||||
|
seq_len, dim = x.shape
|
||||||
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
|
||||||
|
sinusoid_inp = (
|
||||||
|
torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
|
||||||
|
)
|
||||||
|
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
||||||
|
|
||||||
|
def rotate_every_two(x):
|
||||||
|
x1 = x[:, :, ::2]
|
||||||
|
x2 = x[:, :, 1::2]
|
||||||
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
|
||||||
|
|
||||||
|
def duplicate_interleave(m):
|
||||||
|
"""
|
||||||
|
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
|
||||||
|
"""
|
||||||
|
dim0 = m.shape[0]
|
||||||
|
m = m.view(-1, 1) # flatten the matrix
|
||||||
|
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
|
||||||
|
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
|
||||||
|
return m
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(x, sin, cos, scale=1):
|
||||||
|
sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
|
||||||
|
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
|
||||||
|
return (x * cos) + (rotate_every_two(x) * sin)
|
||||||
|
|
||||||
|
|
||||||
|
class XPOS(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, head_dim, scale_base=512
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale_base = scale_base
|
||||||
|
self.register_buffer(
|
||||||
|
"scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, offset=0, downscale=False):
|
||||||
|
length = x.shape[1]
|
||||||
|
min_pos = -(length + offset) // 2
|
||||||
|
max_pos = length + offset - min_pos
|
||||||
|
scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
|
||||||
|
sin, cos = fixed_pos_embedding(scale)
|
||||||
|
|
||||||
|
if scale.shape[0] > length:
|
||||||
|
scale = scale[-length:]
|
||||||
|
sin = sin[-length:]
|
||||||
|
cos = cos[-length:]
|
||||||
|
|
||||||
|
if downscale:
|
||||||
|
scale = 1 / scale
|
||||||
|
|
||||||
|
x = apply_rotary_pos_emb(x, sin, cos, scale)
|
||||||
|
return x
|
Loading…
Reference in New Issue
Block a user