diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 727e611b..038be766 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -29,42 +29,54 @@ LayerIntermediates = namedtuple('Intermediates', [
'attn_intermediates'
])
+
# helpers
def exists(val):
return val is not None
+
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
+
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
+
class always():
def __init__(self, val):
self.val = val
+
def __call__(self, *args, **kwargs):
return self.val
+
class not_equals():
def __init__(self, val):
self.val = val
+
def __call__(self, x, *args, **kwargs):
return x != self.val
+
class equals():
def __init__(self, val):
self.val = val
+
def __call__(self, x, *args, **kwargs):
return x == self.val
+
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
+
def l2norm(t):
- return F.normalize(t, p = 2, dim = -1)
+ return F.normalize(t, p=2, dim=-1)
+
# init helpers
@@ -73,37 +85,44 @@ def init_zero_(layer):
if exists(layer.bias):
nn.init.constant_(layer.bias, 0.)
+
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
+
def group_dict_by_key(cond, d):
- return_val = [dict(),dict()]
+ return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
+
def string_begins_with(prefix, str):
return str.startswith(prefix)
+
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
+
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
+
# activations
class ReluSquared(nn.Module):
def forward(self, x):
return F.relu(x) ** 2
+
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
@@ -113,25 +132,27 @@ class AbsolutePositionalEmbedding(nn.Module):
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x):
- n = torch.arange(x.shape[1], device = x.device)
+ n = torch.arange(x.shape[1], device=x.device)
pos_emb = self.emb(n)
pos_emb = rearrange(pos_emb, 'n d -> () n d')
return pos_emb * self.scale
+
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
- def forward(self, x, seq_dim = 1, offset = 0):
- t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq) + offset
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return rearrange(emb, 'n d -> () n d')
+
class RelativePositionBias(nn.Module):
- def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
super().__init__()
self.scale = scale
self.causal = causal
@@ -140,7 +161,7 @@ class RelativePositionBias(nn.Module):
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
- def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
ret = 0
n = -relative_position
if not causal:
@@ -154,7 +175,7 @@ class RelativePositionBias(nn.Module):
is_small = n < max_exact
val_if_large = max_exact + (
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
@@ -163,35 +184,38 @@ class RelativePositionBias(nn.Module):
def forward(self, qk_dots):
i, j, device = *qk_dots.shape[-2:], qk_dots.device
- q_pos = torch.arange(i, dtype = torch.long, device = device)
- k_pos = torch.arange(j, dtype = torch.long, device = device)
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = k_pos[None, :] - q_pos[:, None]
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
+ max_distance=self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j')
return qk_dots + (bias * self.scale)
+
class AlibiPositionalBias(nn.Module):
def __init__(self, heads, **kwargs):
super().__init__()
self.heads = heads
slopes = torch.Tensor(self._get_slopes(heads))
slopes = rearrange(slopes, 'h -> () h () ()')
- self.register_buffer('slopes', slopes, persistent = False)
- self.register_buffer('bias', None, persistent = False)
+ self.register_buffer('slopes', slopes, persistent=False)
+ self.register_buffer('bias', None, persistent=False)
@staticmethod
def _get_slopes(heads):
def get_slopes_power_of_2(n):
- start = (2**(-2**-(math.log2(n)-3)))
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
- return [start*ratio**i for i in range(n)]
+ return [start * ratio ** i for i in range(n)]
if math.log2(heads).is_integer():
return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
- return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
+ :heads - closest_power_of_2]
def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
@@ -199,18 +223,19 @@ class AlibiPositionalBias(nn.Module):
if exists(self.bias) and self.bias.shape[-1] >= j:
return qk_dots + self.bias[..., :j]
- bias = torch.arange(j, device = device)
+ bias = torch.arange(j, device=device)
bias = rearrange(bias, 'j -> () () () j')
bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[1]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
- self.register_buffer('bias', bias, persistent = False)
+ self.register_buffer('bias', bias, persistent=False)
return qk_dots + self.bias
+
class LearnedAlibiPositionalBias(AlibiPositionalBias):
- def __init__(self, heads, bidirectional = False):
+ def __init__(self, heads, bidirectional=False):
super().__init__(heads)
los_slopes = torch.log(self.slopes)
self.learned_logslopes = nn.Parameter(los_slopes)
@@ -228,10 +253,10 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
if exists(self.bias) and self.bias.shape[-1] >= j:
bias = self.bias[..., :i, :j]
else:
- i_arange = torch.arange(i, device = device)
- j_arange = torch.arange(j, device = device)
+ i_arange = torch.arange(i, device=device)
+ j_arange = torch.arange(j, device=device)
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
- self.register_buffer('bias', bias, persistent = False)
+ self.register_buffer('bias', bias, persistent=False)
if self.bidirectional:
past_slopes = get_slopes(self.learned_logslopes)
@@ -243,6 +268,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
return qk_dots + bias
+
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -250,21 +276,24 @@ class RotaryEmbedding(nn.Module):
self.register_buffer('inv_freq', inv_freq)
def forward(self, max_seq_len, device):
- t = torch.arange(max_seq_len, device = device).type_as(self.inv_freq)
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return rearrange(emb, 'n d -> () () n d')
+
def rotate_half(x):
- x = rearrange(x, '... (j d) -> ... j d', j = 2)
- x1, x2 = x.unbind(dim = -2)
- return torch.cat((-x2, x1), dim = -1)
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
def apply_rotary_pos_emb(t, freqs):
seq_len = t.shape[-2]
freqs = freqs[:, :, -seq_len:]
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+
# norms
class Scale(nn.Module):
@@ -282,6 +311,7 @@ class Scale(nn.Module):
return (scale_fn(out[0]), *out[1:])
+
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -297,49 +327,53 @@ class Rezero(nn.Module):
return (rezero_fn(out[0]), *out[1:])
+
class ScaleNorm(nn.Module):
- def __init__(self, dim, eps = 1e-5):
+ def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
- norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
- return x / norm.clamp(min = self.eps) * self.g
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
class RMSNorm(nn.Module):
- def __init__(self, dim, eps = 1e-8):
+ def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
- norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
- return x / norm.clamp(min = self.eps) * self.g
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
class RMSScaleShiftNorm(nn.Module):
- def __init__(self, dim, eps = 1e-8):
+ def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
- self.scale_shift_process = nn.Linear(dim*2, dim*2)
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
def forward(self, x, norm_scale_shift_inp):
- norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
- norm = x / norm.clamp(min = self.eps) * self.g
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ norm = x / norm.clamp(min=self.eps) * self.g
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
scale, shift = torch.chunk(ss_emb, 2, dim=1)
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return h
+
# residual and residual gates
class Residual(nn.Module):
- def __init__(self, dim, scale_residual = False):
+ def __init__(self, dim, scale_residual=False):
super().__init__()
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -349,8 +383,9 @@ class Residual(nn.Module):
return x + residual
+
class GRUGating(nn.Module):
- def __init__(self, dim, scale_residual = False):
+ def __init__(self, dim, scale_residual=False):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -366,16 +401,18 @@ class GRUGating(nn.Module):
return gated_output.reshape_as(x)
+
# token shifting
-def shift(t, amount, mask = None):
+def shift(t, amount, mask=None):
if amount == 0:
return t
if exists(mask):
t = t.masked_fill(~mask[..., None], 0.)
- return F.pad(t, (0, 0, amount, -amount), value = 0.)
+ return F.pad(t, (0, 0, amount, -amount), value=0.)
+
class ShiftTokens(nn.Module):
def __init__(self, shifts, fn):
@@ -388,12 +425,13 @@ class ShiftTokens(nn.Module):
shifts = self.shifts
segments = len(shifts)
feats_per_shift = x.shape[-1] // segments
- splitted = x.split(feats_per_shift, dim = -1)
+ splitted = x.split(feats_per_shift, dim=-1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
- segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
- x = torch.cat((*segments_to_shift, *rest), dim = -1)
+ segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
return self.fn(x, **kwargs)
+
# feedforward
class GLU(nn.Module):
@@ -403,20 +441,21 @@ class GLU(nn.Module):
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim = -1)
+ x, gate = self.proj(x).chunk(2, dim=-1)
return x * self.act(gate)
+
class FeedForward(nn.Module):
def __init__(
- self,
- dim,
- dim_out = None,
- mult = 4,
- glu = False,
- relu_squared = False,
- post_act_ln = False,
- dropout = 0.,
- zero_init_output = False
+ self,
+ dim,
+ dim_out=None,
+ mult=4,
+ glu=False,
+ relu_squared=False,
+ post_act_ln=False,
+ dropout=0.,
+ zero_init_output=False
):
super().__init__()
inner_dim = int(dim * mult)
@@ -442,32 +481,33 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
+
# attention.
class Attention(nn.Module):
def __init__(
- self,
- dim,
- dim_head = DEFAULT_DIM_HEAD,
- heads = 8,
- causal = False,
- talking_heads = False,
- head_scale = False,
- collab_heads = False,
- collab_compression = .3,
- sparse_topk = None,
- use_entmax15 = False,
- num_mem_kv = 0,
- dropout = 0.,
- on_attn = False,
- gate_values = False,
- zero_init_output = False,
- max_attend_past = None,
- qk_norm = False,
- scale_init_value = None,
- rel_pos_bias = False,
- rel_pos_num_buckets = 32,
- rel_pos_max_distance = 128,
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ talking_heads=False,
+ head_scale=False,
+ collab_heads=False,
+ collab_compression=.3,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False,
+ gate_values=False,
+ zero_init_output=False,
+ max_attend_past=None,
+ qk_norm=False,
+ scale_init_value=None,
+ rel_pos_bias=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -484,9 +524,9 @@ class Attention(nn.Module):
qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
- self.to_q = nn.Linear(dim, qk_dim, bias = False)
- self.to_k = nn.Linear(dim, qk_dim, bias = False)
- self.to_v = nn.Linear(dim, v_dim, bias = False)
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout)
@@ -500,7 +540,8 @@ class Attention(nn.Module):
# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
- scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024
+ scale_init_value = default(scale_init_value,
+ -3) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
# talking heads
@@ -533,25 +574,27 @@ class Attention(nn.Module):
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
- self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
# init output projection 0
if zero_init_output:
init_zero_(self.to_out)
def forward(
- self,
- x,
- context = None,
- mask = None,
- context_mask = None,
- attn_mask = None,
- sinusoidal_emb = None,
- rotary_pos_emb = None,
- prev_attn = None,
- mem = None
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ attn_mask=None,
+ sinusoidal_emb=None,
+ rotary_pos_emb=None,
+ prev_attn=None,
+ mem=None
):
- b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context)
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
+ context)
kv_input = default(context, x)
q_input = x
@@ -559,13 +602,13 @@ class Attention(nn.Module):
v_input = kv_input
if exists(mem):
- k_input = torch.cat((mem, k_input), dim = -2)
- v_input = torch.cat((mem, v_input), dim = -2)
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
- q_input = q_input + sinusoidal_emb(q_input, offset = offset)
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
@@ -573,40 +616,40 @@ class Attention(nn.Module):
v = self.to_v(v_input)
if not collab_heads:
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
else:
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
k = rearrange(k, 'b n d -> b () n d')
- v = rearrange(v, 'b n (h d) -> b h n d', h = h)
+ v = rearrange(v, 'b n (h d) -> b h n d', h=h)
if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
- q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr)))
+ q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
input_mask = None
if any(map(exists, (mask, context_mask))):
- q_mask = default(mask, lambda: torch.ones((b, n), device = device).bool())
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
- k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device = device).bool())
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
- mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
- k = torch.cat((mem_k, k), dim = -2)
- v = torch.cat((mem_v, v), dim = -2)
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
- input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value = True)
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
if collab_heads:
k = k.expand(-1, h, -1, -1)
if self.qk_norm:
q, k = map(l2norm, (q, k))
- scale = 1 / (self.scale.exp().clamp(min = 1e-2))
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
mask_value = max_neg_value(dots)
@@ -636,8 +679,8 @@ class Attention(nn.Module):
if exists(self.max_attend_past):
i, j = dots.shape[-2:]
- range_q = torch.arange(j - i, j, device = device)
- range_k = torch.arange(j, device = device)
+ range_q = torch.arange(j - i, j, device=device)
+ range_k = torch.arange(j, device=device)
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
mask = dist > self.max_attend_past
dots.masked_fill_(mask, mask_value)
@@ -645,20 +688,20 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
- r = torch.arange(i, device = device)
+ r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
- mask = F.pad(mask, (j - i, 0), value = False)
+ mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
- top, _ = dots.topk(self.sparse_topk, dim = -1)
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
- attn = self.attn_fn(dots, dim = -1)
+ attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn.clone()
attn = self.dropout(attn)
@@ -678,46 +721,47 @@ class Attention(nn.Module):
out = out * gates.sigmoid()
intermediates = Intermediates(
- pre_softmax_attn = pre_softmax_attn,
- post_softmax_attn = post_softmax_attn
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates
+
class AttentionLayers(nn.Module):
def __init__(
- self,
- dim,
- depth,
- heads = 8,
- causal = False,
- cross_attend = False,
- only_cross = False,
- use_scalenorm = False,
- use_rms_scaleshift_norm = False,
- use_rmsnorm = False,
- use_rezero = False,
- alibi_pos_bias = False,
- alibi_num_heads = None,
- alibi_learned = False,
- position_infused_attn = False,
- rotary_pos_emb = False,
- rotary_emb_dim = None,
- custom_layers = None,
- sandwich_coef = None,
- par_ratio = None,
- residual_attn = False,
- cross_residual_attn = False,
- macaron = False,
- pre_norm = True,
- gate_residual = False,
- scale_residual = False,
- shift_tokens = 0,
- sandwich_norm = False,
- use_qk_norm_attn = False,
- qk_norm_attn_seq_len = None,
- zero_init_branch_output = False,
- **kwargs
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rms_scaleshift_norm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ alibi_pos_bias=False,
+ alibi_num_heads=None,
+ alibi_learned=False,
+ position_infused_attn=False,
+ rotary_pos_emb=False,
+ rotary_emb_dim=None,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ scale_residual=False,
+ shift_tokens=0,
+ sandwich_norm=False,
+ use_qk_norm_attn=False,
+ qk_norm_attn_seq_len=None,
+ zero_init_branch_output=False,
+ **kwargs
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@@ -736,13 +780,14 @@ class AttentionLayers(nn.Module):
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
- assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
+ assert not (
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
- self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal)
+ self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
else:
self.rel_pos = None
@@ -775,14 +820,15 @@ class AttentionLayers(nn.Module):
# qk normalization
if use_qk_norm_attn:
- attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
+ qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
# zero init
if zero_init_branch_output:
- attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
- ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
# calculate layer block order
@@ -792,7 +838,7 @@ class AttentionLayers(nn.Module):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
- par_attn = par_depth // par_ratio
+ par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
@@ -818,9 +864,9 @@ class AttentionLayers(nn.Module):
is_last_layer = ind == (len(self.layer_types) - 1)
if layer_type == 'a':
- layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
- layer = Attention(dim, heads = heads, **attn_kwargs)
+ layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
@@ -836,7 +882,7 @@ class AttentionLayers(nn.Module):
layer = branch_fn(layer)
residual_fn = GRUGating if gate_residual else Residual
- residual = residual_fn(dim, scale_residual = scale_residual)
+ residual = residual_fn(dim, scale_residual=scale_residual)
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
@@ -857,19 +903,20 @@ class AttentionLayers(nn.Module):
]))
def forward(
- self,
- x,
- context = None,
- full_context = None, # for passing a list of hidden states from an encoder
- mask = None,
- context_mask = None,
- attn_mask = None,
- mems = None,
- return_hiddens = False,
- norm_scale_shift_inp = None,
+ self,
+ x,
+ context=None,
+ full_context=None, # for passing a list of hidden states from an encoder
+ mask=None,
+ context_mask=None,
+ attn_mask=None,
+ mems=None,
+ return_hiddens=False,
+ norm_scale_shift_inp=None,
):
-
- assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
+
+ assert not (self.cross_attend ^ (exists(context) or exists(
+ full_context))), 'context must be passed in if cross_attend is set to True'
assert context is None or full_context is None, 'only one of full_context or context can be provided'
hiddens = []
@@ -900,12 +947,14 @@ class AttentionLayers(nn.Module):
x = pre_branch_norm(x, **norm_args)
if layer_type == 'a':
- out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
+ out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+ prev_attn, layer_mem)
elif layer_type == 'c':
if exists(full_context):
- out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
+ out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+ None, prev_attn)
else:
- out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+ out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
elif layer_type == 'f':
out = checkpoint(block, x)
@@ -933,38 +982,42 @@ class AttentionLayers(nn.Module):
if return_hiddens:
intermediates = LayerIntermediates(
- hiddens = hiddens,
- attn_intermediates = intermediates
+ hiddens=hiddens,
+ attn_intermediates=intermediates
)
return x, intermediates
return x
+
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
- super().__init__(causal = False, **kwargs)
+ super().__init__(causal=False, **kwargs)
+
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder'
- super().__init__(causal = True, **kwargs)
+ super().__init__(causal=True, **kwargs)
+
class CrossAttender(AttentionLayers):
def __init__(self, **kwargs):
- super().__init__(cross_attend = True, only_cross = True, **kwargs)
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
+
class ViTransformerWrapper(nn.Module):
def __init__(
- self,
- *,
- image_size,
- patch_size,
- attn_layers,
- num_classes = None,
- dropout = 0.,
- emb_dropout = 0.
+ self,
+ *,
+ image_size,
+ patch_size,
+ attn_layers,
+ num_classes=None,
+ dropout=0.,
+ emb_dropout=0.
):
super().__init__()
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
@@ -982,20 +1035,20 @@ class ViTransformerWrapper(nn.Module):
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
- self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
+ self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
def forward(
- self,
- img,
- return_embeddings = False
+ self,
+ img,
+ return_embeddings=False
):
p = self.patch_size
- x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
- cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
@@ -1008,20 +1061,21 @@ class ViTransformerWrapper(nn.Module):
return self.mlp_head(x[:, 0])
+
class TransformerWrapper(nn.Module):
def __init__(
- self,
- *,
- num_tokens,
- max_seq_len,
- attn_layers,
- emb_dim = None,
- max_mem_len = 0.,
- shift_mem_down = 0,
- emb_dropout = 0.,
- num_memory_tokens = None,
- tie_embedding = False,
- use_pos_emb = True
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ shift_mem_down=0,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1034,7 +1088,8 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
- self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1055,14 +1110,14 @@ class TransformerWrapper(nn.Module):
nn.init.kaiming_normal_(self.token_emb.weight)
def forward(
- self,
- x,
- return_embeddings = False,
- mask = None,
- return_hiddens = False,
- return_attn = False,
- mems = None,
- **kwargs
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_hiddens=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
@@ -1072,18 +1127,18 @@ class TransformerWrapper(nn.Module):
x = self.project_emb(x)
if num_mem > 0:
- mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
- x = torch.cat((mem, x), dim = 1)
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
- mask = F.pad(mask, (num_mem, 0), value = True)
+ mask = F.pad(mask, (num_mem, 0), value=True)
if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
mems = [*mems_r, *mems_l]
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
@@ -1100,17 +1155,18 @@ class TransformerWrapper(nn.Module):
return out
+
class ContinuousTransformerWrapper(nn.Module):
def __init__(
- self,
- *,
- max_seq_len,
- attn_layers,
- dim_in = None,
- dim_out = None,
- emb_dim = None,
- emb_dropout = 0.,
- use_pos_emb = True
+ self,
+ *,
+ max_seq_len,
+ attn_layers,
+ dim_in=None,
+ dim_out=None,
+ emb_dim=None,
+ emb_dropout=0.,
+ use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1119,7 +1175,8 @@ class ContinuousTransformerWrapper(nn.Module):
self.max_seq_len = max_seq_len
- self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
@@ -1130,13 +1187,13 @@ class ContinuousTransformerWrapper(nn.Module):
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
- self,
- x,
- return_embeddings = False,
- mask = None,
- return_attn = False,
- mems = None,
- **kwargs
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_attn=False,
+ mems=None,
+ **kwargs
):
b, n, _, device = *x.shape, x.device
@@ -1144,7 +1201,7 @@ class ContinuousTransformerWrapper(nn.Module):
x = x + self.pos_emb(x)
x = self.emb_dropout(x)
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
out = self.project_out(x) if not return_embeddings else x
@@ -1155,13 +1212,14 @@ class ContinuousTransformerWrapper(nn.Module):
return out
+
class XTransformer(nn.Module):
def __init__(
- self,
- *,
- dim,
- tie_token_emb = False,
- **kwargs
+ self,
+ *,
+ dim,
+ tie_token_emb=False,
+ **kwargs
):
super().__init__()
enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
@@ -1179,12 +1237,12 @@ class XTransformer(nn.Module):
self.encoder = TransformerWrapper(
**enc_transformer_kwargs,
- attn_layers = Encoder(dim = dim, **enc_kwargs)
+ attn_layers=Encoder(dim=dim, **enc_kwargs)
)
self.decoder = TransformerWrapper(
**dec_transformer_kwargs,
- attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
+ attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
)
if tie_token_emb:
@@ -1193,11 +1251,11 @@ class XTransformer(nn.Module):
self.decoder = AutoregressiveWrapper(self.decoder)
@torch.no_grad()
- def generate(self, seq_in, seq_out_start, seq_len, src_mask = None, src_attn_mask = None, **kwargs):
- encodings = self.encoder(seq_in, mask = src_mask, attn_mask = src_attn_mask, return_embeddings = True)
- return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = src_mask, **kwargs)
+ def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
+ encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
+ return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
- def forward(self, src, tgt, src_mask = None, tgt_mask = None, src_attn_mask = None):
- enc = self.encoder(src, mask = src_mask, attn_mask = src_attn_mask, return_embeddings = True)
- out = self.decoder(tgt, context = enc, mask = tgt_mask, context_mask = src_mask)
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
+ enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
+ out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
return out
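
For reference, a minimal usage sketch of the reformatted wrappers (not part of the diff above; it assumes the module's repository-level imports resolve and uses only the constructor keywords visible in the code being reformatted):

    import torch
    from codes.models.lucidrains.x_transformers import TransformerWrapper, Decoder

    # Decoder(...) builds causal AttentionLayers; TransformerWrapper adds token
    # and positional embeddings plus the output projection back to the vocabulary.
    model = TransformerWrapper(
        num_tokens=20000,
        max_seq_len=1024,
        attn_layers=Decoder(
            dim=512,
            depth=6,
            heads=8,
        )
    )

    x = torch.randint(0, 20000, (1, 1024))   # (batch, seq) token ids
    mask = torch.ones(1, 1024).bool()        # padding mask, True = keep
    logits = model(x, mask=mask)             # (1, 1024, 20000)

The diff itself is behavior-preserving: keyword arguments are normalized to `key=value` spacing, long lines are wrapped, and blank lines are inserted between top-level definitions, so a call like the one above produces the same result before and after the change.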