reformat x_transformers

James Betker 2022-04-07 23:08:03 -07:00
parent 7c578eb59b
commit d05e162f95
3 changed files with 303 additions and 245 deletions

@@ -3,5 +3,5 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (torch)" project-jdk-type="Python SDK" />
</project>

@@ -9,7 +9,7 @@
<excludeFolder url="file://$MODULE_DIR$/results" />
<excludeFolder url="file://$MODULE_DIR$/tb_logger" />
</content>
<orderEntry type="jdk" jdkName="Python 3.8 (torch)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.9 (torch)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

@@ -29,43 +29,55 @@ LayerIntermediates = namedtuple('Intermediates', [
'attn_intermediates'
])
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
class always():
def __init__(self, val):
self.val = val
def __call__(self, *args, **kwargs):
return self.val
class not_equals():
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x != self.val
class equals():
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x == self.val
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
# init helpers
def init_zero_(layer):
@@ -73,12 +85,14 @@ def init_zero_(layer):
if exists(layer.bias):
nn.init.constant_(layer.bias, 0.)
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
@@ -87,23 +101,28 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# activations
class ReluSquared(nn.Module):
def forward(self, x):
return F.relu(x) ** 2
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
@@ -118,6 +137,7 @@ class AbsolutePositionalEmbedding(nn.Module):
pos_emb = rearrange(pos_emb, 'n d -> () n d')
return pos_emb * self.scale
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -130,6 +150,7 @@ class FixedPositionalEmbedding(nn.Module):
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return rearrange(emb, 'n d -> () n d')
class RelativePositionBias(nn.Module):
def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
super().__init__()
@@ -166,11 +187,13 @@ class RelativePositionBias(nn.Module):
q_pos = torch.arange(i, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = k_pos[None, :] - q_pos[:, None]
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
+ max_distance=self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j')
return qk_dots + (bias * self.scale)
class AlibiPositionalBias(nn.Module):
def __init__(self, heads, **kwargs):
super().__init__()
@@ -191,7 +214,8 @@ class AlibiPositionalBias(nn.Module):
return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
- return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
+ :heads - closest_power_of_2]
def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
@@ -209,6 +233,7 @@ class AlibiPositionalBias(nn.Module):
self.register_buffer('bias', bias, persistent=False)
return qk_dots + self.bias
class LearnedAlibiPositionalBias(AlibiPositionalBias):
def __init__(self, heads, bidirectional=False):
super().__init__(heads)
@@ -243,6 +268,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
return qk_dots + bias
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -255,16 +281,19 @@ class RotaryEmbedding(nn.Module):
emb = torch.cat((freqs, freqs), dim=-1)
return rearrange(emb, 'n d -> () () n d')
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
seq_len = t.shape[-2]
freqs = freqs[:, :, -seq_len:]
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
# norms
class Scale(nn.Module):
@@ -282,6 +311,7 @@ class Scale(nn.Module):
return (scale_fn(out[0]), *out[1:])
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -297,6 +327,7 @@ class Rezero(nn.Module):
return (rezero_fn(out[0]), *out[1:])
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
@@ -308,6 +339,7 @@ class ScaleNorm(nn.Module):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
@@ -319,6 +351,7 @@ class RMSNorm(nn.Module):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
@@ -336,6 +369,7 @@ class RMSScaleShiftNorm(nn.Module):
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return h
# residual and residual gates
class Residual(nn.Module):
@@ -349,6 +383,7 @@ class Residual(nn.Module):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim, scale_residual=False):
super().__init__()
@@ -366,6 +401,7 @@ class GRUGating(nn.Module):
return gated_output.reshape_as(x)
# token shifting
def shift(t, amount, mask=None):
@@ -377,6 +413,7 @@ def shift(t, amount, mask = None):
return F.pad(t, (0, 0, amount, -amount), value=0.)
class ShiftTokens(nn.Module):
def __init__(self, shifts, fn):
super().__init__()
@@ -394,6 +431,7 @@ class ShiftTokens(nn.Module):
x = torch.cat((*segments_to_shift, *rest), dim=-1)
return self.fn(x, **kwargs)
# feedforward
class GLU(nn.Module):
@@ -406,6 +444,7 @@ class GLU(nn.Module):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * self.act(gate)
class FeedForward(nn.Module):
def __init__(
self,
@@ -442,6 +481,7 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
# attention.
class Attention(nn.Module):
@@ -500,7 +540,8 @@ class Attention(nn.Module):
# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
- scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024
+ scale_init_value = default(scale_init_value,
+ -3) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
# talking heads
@@ -533,7 +574,8 @@ class Attention(nn.Module):
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
- self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
# init output projection 0
if zero_init_output:
@@ -551,7 +593,8 @@ class Attention(nn.Module):
prev_attn=None,
mem=None
):
- b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context)
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
+ context)
kv_input = default(context, x)
q_input = x
@@ -684,6 +727,7 @@ class Attention(nn.Module):
return self.to_out(out), intermediates
class AttentionLayers(nn.Module):
def __init__(
self,
@@ -736,7 +780,8 @@ class AttentionLayers(nn.Module):
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
- assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
+ assert not (
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
@@ -775,7 +820,8 @@ class AttentionLayers(nn.Module):
# qk normalization
if use_qk_norm_attn:
- attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
+ qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
# zero init
@@ -869,7 +915,8 @@ class AttentionLayers(nn.Module):
norm_scale_shift_inp=None,
):
- assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
+ assert not (self.cross_attend ^ (exists(context) or exists(
+ full_context))), 'context must be passed in if cross_attend is set to True'
assert context is None or full_context is None, 'only one of full_context or context can be provided'
hiddens = []
@@ -900,10 +947,12 @@ class AttentionLayers(nn.Module):
x = pre_branch_norm(x, **norm_args)
if layer_type == 'a':
- out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
+ out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+ prev_attn, layer_mem)
elif layer_type == 'c':
if exists(full_context):
- out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
+ out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+ None, prev_attn)
else:
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
elif layer_type == 'f':
@@ -941,20 +990,24 @@ class AttentionLayers(nn.Module):
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder'
super().__init__(causal=True, **kwargs)
class CrossAttender(AttentionLayers):
def __init__(self, **kwargs):
super().__init__(cross_attend=True, only_cross=True, **kwargs)
class ViTransformerWrapper(nn.Module):
def __init__(
self,
@@ -1008,6 +1061,7 @@ class ViTransformerWrapper(nn.Module):
return self.mlp_head(x[:, 0])
class TransformerWrapper(nn.Module):
def __init__(
self,
@@ -1034,7 +1088,8 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
- self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1100,6 +1155,7 @@ class TransformerWrapper(nn.Module):
return out
class ContinuousTransformerWrapper(nn.Module):
def __init__(
self,
@@ -1119,7 +1175,8 @@ class ContinuousTransformerWrapper(nn.Module):
self.max_seq_len = max_seq_len
- self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
@@ -1155,6 +1212,7 @@ class ContinuousTransformerWrapper(nn.Module):
return out
class XTransformer(nn.Module):
def __init__(
self,