Update XPos
This commit is contained in:
parent
f9d98f4b68
commit
9d968a24ed
|
@ -84,6 +84,9 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
|
||||||
* enabled by *multiway=True*.
|
* enabled by *multiway=True*.
|
||||||
* It provides a pool of Transformer's parameters used for different modalities.
|
* It provides a pool of Transformer's parameters used for different modalities.
|
||||||
|
|
||||||
|
- [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554)
|
||||||
|
* enabled by *xpos_rel_pos=True*.
|
||||||
|
|
||||||
- [Relative position bias](https://arxiv.org/abs/1910.10683)
|
- [Relative position bias](https://arxiv.org/abs/1910.10683)
|
||||||
* enabled by adjusting *rel_pos_buckets* and *max_rel_pos*.
|
* enabled by adjusting *rel_pos_buckets* and *max_rel_pos*.
|
||||||
|
|
||||||
|
|
|
@ -190,6 +190,12 @@ class LanguageConfig(FairseqDataclass):
|
||||||
max_rel_pos: Optional[int] = field(
|
max_rel_pos: Optional[int] = field(
|
||||||
default=0,
|
default=0,
|
||||||
)
|
)
|
||||||
|
xpos_rel_pos: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
xpos_scale_base: Optional[int] = field(
|
||||||
|
default=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model("lm", dataclass=LanguageConfig)
|
@register_model("lm", dataclass=LanguageConfig)
|
||||||
|
|
|
@ -49,7 +49,8 @@ class EncoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
self.xpos = kwargs.pop("xpos", False)
|
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||||
|
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.encoder_normalize_before = False
|
self.encoder_normalize_before = False
|
||||||
|
@ -111,7 +112,8 @@ class DecoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
self.xpos = kwargs.pop("xpos", False)
|
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||||
|
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.decoder_normalize_before = False
|
self.decoder_normalize_before = False
|
||||||
|
@ -180,7 +182,8 @@ class EncoderDecoderConfig(object):
|
||||||
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
|
||||||
self.fsdp = kwargs.pop("fsdp", False)
|
self.fsdp = kwargs.pop("fsdp", False)
|
||||||
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
self.ddp_rank = kwargs.pop("ddp_rank", 0)
|
||||||
self.xpos = kwargs.pop("xpos", False)
|
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
|
||||||
|
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
|
||||||
|
|
||||||
if self.deepnorm:
|
if self.deepnorm:
|
||||||
self.encoder_normalize_before = False
|
self.encoder_normalize_before = False
|
||||||
|
|
|
@ -9,7 +9,7 @@ from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .multiway_network import MultiwayWrapper
|
from .multiway_network import MultiwayWrapper
|
||||||
from .xpos import XPOS
|
from .xpos_relative_position import XPOS
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(nn.Module):
|
class MultiheadAttention(nn.Module):
|
||||||
|
@ -46,8 +46,8 @@ class MultiheadAttention(nn.Module):
|
||||||
)
|
)
|
||||||
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.xpos = (
|
self.xpos = (
|
||||||
XPOS(self.head_dim)
|
XPOS(self.head_dim, args.xpos_scale_base)
|
||||||
if args.xpos and self.self_attention
|
if args.xpos_rel_pos and self.self_attention
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -110,8 +110,8 @@ class MultiheadAttention(nn.Module):
|
||||||
offset = src_len - 1
|
offset = src_len - 1
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
k = self.xpos(k, downscale=True)
|
k = self.xpos(k, offset=0, downscale=True)
|
||||||
q = self.xpos(q, offset=offset)
|
q = self.xpos(q, offset=offset, downscale=False)
|
||||||
|
|
||||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class XPOS(nn.Module):
|
||||||
def forward(self, x, offset=0, downscale=False):
|
def forward(self, x, offset=0, downscale=False):
|
||||||
length = x.shape[1]
|
length = x.shape[1]
|
||||||
min_pos = -(length + offset) // 2
|
min_pos = -(length + offset) // 2
|
||||||
max_pos = length + offset - min_pos
|
max_pos = length + offset + min_pos
|
||||||
scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
|
scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
|
||||||
sin, cos = fixed_pos_embedding(scale)
|
sin, cos = fixed_pos_embedding(scale)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user