forked from mrq/DL-Art-School

reformat x_transformers

This commit is contained in:
parent 7c578eb59b
commit d05e162f95
@@ -3,5 +3,5 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (torch)" project-jdk-type="Python SDK" />
 </project>
@@ -9,7 +9,7 @@
       <excludeFolder url="file://$MODULE_DIR$/results" />
       <excludeFolder url="file://$MODULE_DIR$/tb_logger" />
     </content>
-    <orderEntry type="jdk" jdkName="Python 3.8 (torch)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Python 3.9 (torch)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="PyDocumentationSettings">
@@ -29,43 +29,55 @@ LayerIntermediates = namedtuple('Intermediates', [
    'attn_intermediates'
])


# helpers

def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth


class always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val


class not_equals():
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return x != self.val


class equals():
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return x == self.val


def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


# init helpers

def init_zero_(layer):
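A quick sketch of the helper semantics in the hunk above (illustrative values; assumes exists, default and cast_tuple are importable from this module):

# illustrative checks, not part of this commit
assert exists(0) and not exists(None)             # exists() only tests for None
assert default(None, 3) == 3                      # plain fallback value
assert default(None, lambda: 7) == 7              # callable fallbacks are invoked lazily
assert cast_tuple('gelu', 2) == ('gelu', 'gelu')  # scalars broadcast to a tuple of the given depth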
@@ -73,12 +85,14 @@ def init_zero_(layer):
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)


# keyword argument helpers

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
@@ -87,23 +101,28 @@ def group_dict_by_key(cond, d):
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, str):
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs


# activations

class ReluSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2


# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
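For context, a small sketch (hypothetical values; the helpers above are assumed importable from this module) of how the prefix-grouping utilities route keyword arguments:

kwargs = {'ff_mult': 4, 'ff_glu': True, 'attn_dim_head': 64}
ff_kwargs, rest = groupby_prefix_and_trim('ff_', kwargs)
# ff_kwargs == {'mult': 4, 'glu': True}   -- prefix stripped off
# rest == {'attn_dim_head': 64}           -- untouched leftovers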
@@ -118,6 +137,7 @@ class AbsolutePositionalEmbedding(nn.Module):
        pos_emb = rearrange(pos_emb, 'n d -> () n d')
        return pos_emb * self.scale


class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -130,6 +150,7 @@ class FixedPositionalEmbedding(nn.Module):
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return rearrange(emb, 'n d -> () n d')


class RelativePositionBias(nn.Module):
    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
@@ -166,11 +187,13 @@ class RelativePositionBias(nn.Module):
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
-       rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
+       rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
+                                                  max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)


class AlibiPositionalBias(nn.Module):
    def __init__(self, heads, **kwargs):
        super().__init__()
@@ -191,7 +214,8 @@ class AlibiPositionalBias(nn.Module):
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
-       return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+       return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
+              :heads - closest_power_of_2]

    def forward(self, qk_dots):
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
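The wrapped fallback above handles head counts that are not a power of two. A standalone sketch of the whole slope computation follows; get_slopes_power_of_2 is not visible in this hunk, so the geometric-series body below is assumed from the standard ALiBi recipe:

import math

def get_slopes_power_of_2(n):
    # assumed standard ALiBi slopes: a geometric series starting at 2 ** (-8 / n)
    start = 2 ** (-2 ** -(math.log2(n) - 3))
    return [start * (start ** i) for i in range(n)]

def get_slopes(heads):
    if math.log2(heads).is_integer():
        return get_slopes_power_of_2(heads)
    closest_power_of_2 = 2 ** math.floor(math.log2(heads))
    # non-power-of-two head counts: use the closest power of two, then interleave slopes from the next one up
    return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
           :heads - closest_power_of_2]

print(get_slopes(12))  # 8 slopes for the power-of-two part plus 4 interleaved from the 16-head set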
@@ -209,6 +233,7 @@ class AlibiPositionalBias(nn.Module):
        self.register_buffer('bias', bias, persistent=False)
        return qk_dots + self.bias


class LearnedAlibiPositionalBias(AlibiPositionalBias):
    def __init__(self, heads, bidirectional=False):
        super().__init__(heads)
@@ -243,6 +268,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):

        return qk_dots + bias


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -255,16 +281,19 @@ class RotaryEmbedding(nn.Module):
        emb = torch.cat((freqs, freqs), dim=-1)
        return rearrange(emb, 'n d -> () () n d')


def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    seq_len = t.shape[-2]
    freqs = freqs[:, :, -seq_len:]
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())


# norms

class Scale(nn.Module):
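A standalone sketch of how the rotary helpers above are applied to a query tensor; apply_rotary_pos_emb is assumed importable from this module, and the frequency construction mirrors what RotaryEmbedding produces (the inv_freq base and shapes here are illustrative):

import torch
from einops import rearrange

dim_head, seq_len, heads = 64, 16, 8
inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
t = torch.arange(seq_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)   # (n, d/2)
freqs = torch.cat((freqs, freqs), dim=-1)      # (n, d)
freqs = rearrange(freqs, 'n d -> () () n d')   # broadcast over batch and heads

q = torch.randn(1, heads, seq_len, dim_head)
q_rot = apply_rotary_pos_emb(q, freqs)         # same shape as q, with positions rotated in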
@@ -282,6 +311,7 @@ class Scale(nn.Module):

        return (scale_fn(out[0]), *out[1:])


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
@@ -297,6 +327,7 @@ class Rezero(nn.Module):

        return (rezero_fn(out[0]), *out[1:])


class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
@@ -308,6 +339,7 @@ class ScaleNorm(nn.Module):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
@@ -319,6 +351,7 @@ class RMSNorm(nn.Module):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class RMSScaleShiftNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
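A quick numeric check (not from this commit) of what the RMSNorm forward above computes, assuming self.scale is dim ** -0.5 as in upstream x_transformers: the L2 norm scaled by dim ** -0.5 equals the root-mean-square, so the layer is x / sqrt(mean(x ** 2)) times a learned gain.

import torch

x = torch.randn(2, 5, 64)
dim = x.shape[-1]
rms_a = torch.norm(x, dim=-1, keepdim=True) * dim ** -0.5   # the norm line in RMSNorm.forward above
rms_b = x.pow(2).mean(dim=-1, keepdim=True).sqrt()          # textbook root-mean-square
assert torch.allclose(rms_a, rms_b, atol=1e-5)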
@@ -336,6 +369,7 @@ class RMSScaleShiftNorm(nn.Module):
        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return h


# residual and residual gates

class Residual(nn.Module):
@@ -349,6 +383,7 @@ class Residual(nn.Module):

        return x + residual


class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual=False):
        super().__init__()
@@ -366,6 +401,7 @@ class GRUGating(nn.Module):

        return gated_output.reshape_as(x)


# token shifting

def shift(t, amount, mask=None):
@@ -377,6 +413,7 @@ def shift(t, amount, mask = None):

    return F.pad(t, (0, 0, amount, -amount), value=0.)


class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
@@ -394,6 +431,7 @@ class ShiftTokens(nn.Module):
        x = torch.cat((*segments_to_shift, *rest), dim=-1)
        return self.fn(x, **kwargs)


# feedforward

class GLU(nn.Module):
@@ -406,6 +444,7 @@ class GLU(nn.Module):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * self.act(gate)


class FeedForward(nn.Module):
    def __init__(
            self,
@@ -442,6 +481,7 @@ class FeedForward(nn.Module):
    def forward(self, x):
        return self.net(x)


# attention.

class Attention(nn.Module):
@@ -500,7 +540,8 @@ class Attention(nn.Module):
        # cosine sim attention
        self.qk_norm = qk_norm
        if qk_norm:
-           scale_init_value = default(scale_init_value, -3)  # if not provided, initialize as though it were sequence length of 1024
+           scale_init_value = default(scale_init_value,
+                                      -3)  # if not provided, initialize as though it were sequence length of 1024
            self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)

        # talking heads
@@ -533,7 +574,8 @@ class Attention(nn.Module):
        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
            assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
-           self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+           self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
+                                               num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)

        # init output projection 0
        if zero_init_output:
@@ -551,7 +593,8 @@ class Attention(nn.Module):
            prev_attn=None,
            mem=None
    ):
-       b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context)
+       b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
+           context)
        kv_input = default(context, x)

        q_input = x
@@ -684,6 +727,7 @@ class Attention(nn.Module):

        return self.to_out(out), intermediates


class AttentionLayers(nn.Module):
    def __init__(
            self,
@@ -736,7 +780,8 @@ class AttentionLayers(nn.Module):
        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None

-       assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
+       assert not (
+           alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'

        if alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
@@ -775,7 +820,8 @@ class AttentionLayers(nn.Module):
        # qk normalization

        if use_qk_norm_attn:
-           attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
+           attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
+               qk_norm_attn_seq_len) else None
            attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

        # zero init
@@ -869,7 +915,8 @@ class AttentionLayers(nn.Module):
            norm_scale_shift_inp=None,
    ):

-       assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
+       assert not (self.cross_attend ^ (exists(context) or exists(
+           full_context))), 'context must be passed in if cross_attend is set to True'
        assert context is None or full_context is None, 'only one of full_context or context can be provided'

        hiddens = []
@@ -900,10 +947,12 @@ class AttentionLayers(nn.Module):
                x = pre_branch_norm(x, **norm_args)

            if layer_type == 'a':
-               out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
+               out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                       prev_attn, layer_mem)
            elif layer_type == 'c':
                if exists(full_context):
-                   out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
+                   out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                           None, prev_attn)
                else:
                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
            elif layer_type == 'f':
@@ -941,20 +990,24 @@ class AttentionLayers(nn.Module):

        return x


class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)


class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal=True, **kwargs)


class CrossAttender(AttentionLayers):
    def __init__(self, **kwargs):
        super().__init__(cross_attend=True, only_cross=True, **kwargs)


class ViTransformerWrapper(nn.Module):
    def __init__(
            self,
@@ -1008,6 +1061,7 @@ class ViTransformerWrapper(nn.Module):

        return self.mlp_head(x[:, 0])


class TransformerWrapper(nn.Module):
    def __init__(
            self,
@@ -1034,7 +1088,8 @@ class TransformerWrapper(nn.Module):
        self.shift_mem_down = shift_mem_down

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
-       self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+       self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+               use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1100,6 +1155,7 @@ class TransformerWrapper(nn.Module):

        return out


class ContinuousTransformerWrapper(nn.Module):
    def __init__(
            self,
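A minimal usage sketch (illustrative sizes, not part of this commit) combining the Decoder and TransformerWrapper classes from the hunks above into a small autoregressive model; default keyword arguments are assumed to behave as in upstream x_transformers:

import torch

model = TransformerWrapper(
    num_tokens=256,                       # vocabulary size (illustrative)
    max_seq_len=1024,
    attn_layers=Decoder(dim=512, depth=6, heads=8)
)

tokens = torch.randint(0, 256, (1, 128))  # (batch, sequence)
logits = model(tokens)                    # expected shape: (1, 128, 256)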
@@ -1119,7 +1175,8 @@ class ContinuousTransformerWrapper(nn.Module):

        self.max_seq_len = max_seq_len

-       self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+       self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
+               use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
@@ -1155,6 +1212,7 @@ class ContinuousTransformerWrapper(nn.Module):

        return out


class XTransformer(nn.Module):
    def __init__(
            self,