From d05e162f959eaf6e3a04f8142bade229d5dd28e9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 7 Apr 2022 23:08:03 -0700 Subject: [PATCH] reformat x_transformers --- .idea/misc.xml | 2 +- .idea/mmsr.iml | 2 +- codes/models/lucidrains/x_transformers.py | 544 ++++++++++++---------- 3 files changed, 303 insertions(+), 245 deletions(-) diff --git a/.idea/misc.xml b/.idea/misc.xml index a370a267..0adf3fba 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/.idea/mmsr.iml b/.idea/mmsr.iml index b63a5458..b06487d4 100644 --- a/.idea/mmsr.iml +++ b/.idea/mmsr.iml @@ -9,7 +9,7 @@ - + diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index 727e611b..038be766 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -29,42 +29,54 @@ LayerIntermediates = namedtuple('Intermediates', [ 'attn_intermediates' ]) + # helpers def exists(val): return val is not None + def default(val, d): if exists(val): return val return d() if isfunction(d) else d + def cast_tuple(val, depth): return val if isinstance(val, tuple) else (val,) * depth + class always(): def __init__(self, val): self.val = val + def __call__(self, *args, **kwargs): return self.val + class not_equals(): def __init__(self, val): self.val = val + def __call__(self, x, *args, **kwargs): return x != self.val + class equals(): def __init__(self, val): self.val = val + def __call__(self, x, *args, **kwargs): return x == self.val + def max_neg_value(tensor): return -torch.finfo(tensor.dtype).max + def l2norm(t): - return F.normalize(t, p = 2, dim = -1) + return F.normalize(t, p=2, dim=-1) + # init helpers @@ -73,37 +85,44 @@ def init_zero_(layer): if exists(layer.bias): nn.init.constant_(layer.bias, 0.) + # keyword argument helpers def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) + def group_dict_by_key(cond, d): - return_val = [dict(),dict()] + return_val = [dict(), dict()] for key in d.keys(): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] return (*return_val,) + def string_begins_with(prefix, str): return str.startswith(prefix) + def group_by_key_prefix(prefix, d): return group_dict_by_key(partial(string_begins_with, prefix), d) + def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs + # activations class ReluSquared(nn.Module): def forward(self, x): return F.relu(x) ** 2 + # positional embeddings class AbsolutePositionalEmbedding(nn.Module): @@ -113,25 +132,27 @@ class AbsolutePositionalEmbedding(nn.Module): self.emb = nn.Embedding(max_seq_len, dim) def forward(self, x): - n = torch.arange(x.shape[1], device = x.device) + n = torch.arange(x.shape[1], device=x.device) pos_emb = self.emb(n) pos_emb = rearrange(pos_emb, 'n d -> () n d') return pos_emb * self.scale + class FixedPositionalEmbedding(nn.Module): def __init__(self, dim): super().__init__() inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) - def forward(self, x, seq_dim = 1, offset = 0): - t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq) + offset + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return rearrange(emb, 'n d -> () n d') + class RelativePositionBias(nn.Module): - def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): super().__init__() self.scale = scale self.causal = causal @@ -140,7 +161,7 @@ class RelativePositionBias(nn.Module): self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod - def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): ret = 0 n = -relative_position if not causal: @@ -154,7 +175,7 @@ class RelativePositionBias(nn.Module): is_small = n < max_exact val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) @@ -163,35 +184,38 @@ class RelativePositionBias(nn.Module): def forward(self, qk_dots): i, j, device = *qk_dots.shape[-2:], qk_dots.device - q_pos = torch.arange(i, dtype = torch.long, device = device) - k_pos = torch.arange(j, dtype = torch.long, device = device) + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) rel_pos = k_pos[None, :] - q_pos[:, None] - rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) + rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, + max_distance=self.max_distance) values = self.relative_attention_bias(rp_bucket) bias = rearrange(values, 'i j h -> () h i j') return qk_dots + (bias * self.scale) + class AlibiPositionalBias(nn.Module): def __init__(self, heads, **kwargs): super().__init__() self.heads = heads slopes = torch.Tensor(self._get_slopes(heads)) slopes = rearrange(slopes, 'h -> () h () ()') - self.register_buffer('slopes', slopes, persistent = False) - self.register_buffer('bias', None, persistent = False) + self.register_buffer('slopes', slopes, persistent=False) + self.register_buffer('bias', None, persistent=False) @staticmethod def _get_slopes(heads): def get_slopes_power_of_2(n): - start = (2**(-2**-(math.log2(n)-3))) + start = (2 ** (-2 ** -(math.log2(n) - 3))) ratio = start - return [start*ratio**i for i in range(n)] + return [start * ratio ** i for i in range(n)] if math.log2(heads).is_integer(): return get_slopes_power_of_2(heads) closest_power_of_2 = 2 ** math.floor(math.log2(heads)) - return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] + return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + :heads - closest_power_of_2] def forward(self, 
qk_dots): h, i, j, device = *qk_dots.shape[-3:], qk_dots.device @@ -199,18 +223,19 @@ class AlibiPositionalBias(nn.Module): if exists(self.bias) and self.bias.shape[-1] >= j: return qk_dots + self.bias[..., :j] - bias = torch.arange(j, device = device) + bias = torch.arange(j, device=device) bias = rearrange(bias, 'j -> () () () j') bias = bias * self.slopes num_heads_unalibied = h - bias.shape[1] bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) - self.register_buffer('bias', bias, persistent = False) + self.register_buffer('bias', bias, persistent=False) return qk_dots + self.bias + class LearnedAlibiPositionalBias(AlibiPositionalBias): - def __init__(self, heads, bidirectional = False): + def __init__(self, heads, bidirectional=False): super().__init__(heads) los_slopes = torch.log(self.slopes) self.learned_logslopes = nn.Parameter(los_slopes) @@ -228,10 +253,10 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias): if exists(self.bias) and self.bias.shape[-1] >= j: bias = self.bias[..., :i, :j] else: - i_arange = torch.arange(i, device = device) - j_arange = torch.arange(j, device = device) + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') - self.register_buffer('bias', bias, persistent = False) + self.register_buffer('bias', bias, persistent=False) if self.bidirectional: past_slopes = get_slopes(self.learned_logslopes) @@ -243,6 +268,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias): return qk_dots + bias + class RotaryEmbedding(nn.Module): def __init__(self, dim): super().__init__() @@ -250,21 +276,24 @@ class RotaryEmbedding(nn.Module): self.register_buffer('inv_freq', inv_freq) def forward(self, max_seq_len, device): - t = torch.arange(max_seq_len, device = device).type_as(self.inv_freq) + t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) freqs = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) return rearrange(emb, 'n d -> () () n d') + def rotate_half(x): - x = rearrange(x, '... (j d) -> ... j d', j = 2) - x1, x2 = x.unbind(dim = -2) - return torch.cat((-x2, x1), dim = -1) + x = rearrange(x, '... (j d) -> ... 
j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(t, freqs): seq_len = t.shape[-2] freqs = freqs[:, :, -seq_len:] return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + # norms class Scale(nn.Module): @@ -282,6 +311,7 @@ class Scale(nn.Module): return (scale_fn(out[0]), *out[1:]) + class Rezero(nn.Module): def __init__(self, fn): super().__init__() @@ -297,49 +327,53 @@ class Rezero(nn.Module): return (rezero_fn(out[0]), *out[1:]) + class ScaleNorm(nn.Module): - def __init__(self, dim, eps = 1e-5): + def __init__(self, dim, eps=1e-5): super().__init__() self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) def forward(self, x): - norm = torch.norm(x, dim = -1, keepdim = True) * self.scale - return x / norm.clamp(min = self.eps) * self.g + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + class RMSNorm(nn.Module): - def __init__(self, dim, eps = 1e-8): + def __init__(self, dim, eps=1e-8): super().__init__() self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): - norm = torch.norm(x, dim = -1, keepdim = True) * self.scale - return x / norm.clamp(min = self.eps) * self.g + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + class RMSScaleShiftNorm(nn.Module): - def __init__(self, dim, eps = 1e-8): + def __init__(self, dim, eps=1e-8): super().__init__() self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) - self.scale_shift_process = nn.Linear(dim*2, dim*2) + self.scale_shift_process = nn.Linear(dim * 2, dim * 2) def forward(self, x, norm_scale_shift_inp): - norm = torch.norm(x, dim = -1, keepdim = True) * self.scale - norm = x / norm.clamp(min = self.eps) * self.g + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + norm = x / norm.clamp(min=self.eps) * self.g ss_emb = self.scale_shift_process(norm_scale_shift_inp) scale, shift = torch.chunk(ss_emb, 2, dim=1) h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return h + # residual and residual gates class Residual(nn.Module): - def __init__(self, dim, scale_residual = False): + def __init__(self, dim, scale_residual=False): super().__init__() self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None @@ -349,8 +383,9 @@ class Residual(nn.Module): return x + residual + class GRUGating(nn.Module): - def __init__(self, dim, scale_residual = False): + def __init__(self, dim, scale_residual=False): super().__init__() self.gru = nn.GRUCell(dim, dim) self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None @@ -366,16 +401,18 @@ class GRUGating(nn.Module): return gated_output.reshape_as(x) + # token shifting -def shift(t, amount, mask = None): +def shift(t, amount, mask=None): if amount == 0: return t if exists(mask): t = t.masked_fill(~mask[..., None], 0.) - return F.pad(t, (0, 0, amount, -amount), value = 0.) + return F.pad(t, (0, 0, amount, -amount), value=0.) 
+ class ShiftTokens(nn.Module): def __init__(self, shifts, fn): @@ -388,12 +425,13 @@ class ShiftTokens(nn.Module): shifts = self.shifts segments = len(shifts) feats_per_shift = x.shape[-1] // segments - splitted = x.split(feats_per_shift, dim = -1) + splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) - x = torch.cat((*segments_to_shift, *rest), dim = -1) + segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) + # feedforward class GLU(nn.Module): @@ -403,20 +441,21 @@ class GLU(nn.Module): self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): - x, gate = self.proj(x).chunk(2, dim = -1) + x, gate = self.proj(x).chunk(2, dim=-1) return x * self.act(gate) + class FeedForward(nn.Module): def __init__( - self, - dim, - dim_out = None, - mult = 4, - glu = False, - relu_squared = False, - post_act_ln = False, - dropout = 0., - zero_init_output = False + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0., + zero_init_output=False ): super().__init__() inner_dim = int(dim * mult) @@ -442,32 +481,33 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) + # attention. class Attention(nn.Module): def __init__( - self, - dim, - dim_head = DEFAULT_DIM_HEAD, - heads = 8, - causal = False, - talking_heads = False, - head_scale = False, - collab_heads = False, - collab_compression = .3, - sparse_topk = None, - use_entmax15 = False, - num_mem_kv = 0, - dropout = 0., - on_attn = False, - gate_values = False, - zero_init_output = False, - max_attend_past = None, - qk_norm = False, - scale_init_value = None, - rel_pos_bias = False, - rel_pos_num_buckets = 32, - rel_pos_max_distance = 128, + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, ): super().__init__() self.scale = dim_head ** -0.5 @@ -484,9 +524,9 @@ class Attention(nn.Module): qk_dim = int(collab_compression * qk_dim) self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) - self.to_q = nn.Linear(dim, qk_dim, bias = False) - self.to_k = nn.Linear(dim, qk_dim, bias = False) - self.to_v = nn.Linear(dim, v_dim, bias = False) + self.to_q = nn.Linear(dim, qk_dim, bias=False) + self.to_k = nn.Linear(dim, qk_dim, bias=False) + self.to_v = nn.Linear(dim, v_dim, bias=False) self.dropout = nn.Dropout(dropout) @@ -500,7 +540,8 @@ class Attention(nn.Module): # cosine sim attention self.qk_norm = qk_norm if qk_norm: - scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024 + scale_init_value = default(scale_init_value, + -3) # if not provided, initialize as though it were sequence length of 1024 self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) # talking heads @@ -533,25 +574,27 @@ class Attention(nn.Module): self.rel_pos_bias = rel_pos_bias if rel_pos_bias: assert rel_pos_num_buckets <= rel_pos_max_distance, 'number 
of relative position buckets must be less than the relative position max distance' - self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) + self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads, + num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) # init output projection 0 if zero_init_output: init_zero_(self.to_out) def forward( - self, - x, - context = None, - mask = None, - context_mask = None, - attn_mask = None, - sinusoidal_emb = None, - rotary_pos_emb = None, - prev_attn = None, - mem = None + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None ): - b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context) + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( + context) kv_input = default(context, x) q_input = x @@ -559,13 +602,13 @@ class Attention(nn.Module): v_input = kv_input if exists(mem): - k_input = torch.cat((mem, k_input), dim = -2) - v_input = torch.cat((mem, v_input), dim = -2) + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) if exists(sinusoidal_emb): # in shortformer, the query would start at a position offset depending on the past cached memory offset = k_input.shape[-2] - q_input.shape[-2] - q_input = q_input + sinusoidal_emb(q_input, offset = offset) + q_input = q_input + sinusoidal_emb(q_input, offset=offset) k_input = k_input + sinusoidal_emb(k_input) q = self.to_q(q_input) @@ -573,40 +616,40 @@ class Attention(nn.Module): v = self.to_v(v_input) if not collab_heads: - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) else: q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) k = rearrange(k, 'b n d -> b () n d') - v = rearrange(v, 'b n (h d) -> b h n d', h = h) + v = rearrange(v, 'b n (h d) -> b h n d', h=h) if exists(rotary_pos_emb) and not has_context: l = rotary_pos_emb.shape[-1] (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) - q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr))) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) input_mask = None if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device = device).bool()) + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device = device).bool()) + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) q_mask = rearrange(q_mask, 'b i -> b () i ()') k_mask = rearrange(k_mask, 'b j -> b () () j') input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v)) - k = torch.cat((mem_k, k), dim = -2) - v = torch.cat((mem_v, v), dim = -2) + mem_k, mem_v = 
map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value = True) + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) if collab_heads: k = k.expand(-1, h, -1, -1) if self.qk_norm: q, k = map(l2norm, (q, k)) - scale = 1 / (self.scale.exp().clamp(min = 1e-2)) + scale = 1 / (self.scale.exp().clamp(min=1e-2)) dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale mask_value = max_neg_value(dots) @@ -636,8 +679,8 @@ class Attention(nn.Module): if exists(self.max_attend_past): i, j = dots.shape[-2:] - range_q = torch.arange(j - i, j, device = device) - range_k = torch.arange(j, device = device) + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') mask = dist > self.max_attend_past dots.masked_fill_(mask, mask_value) @@ -645,20 +688,20 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] - r = torch.arange(i, device = device) + r = torch.arange(i, device=device) mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') - mask = F.pad(mask, (j - i, 0), value = False) + mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: - top, _ = dots.topk(self.sparse_topk, dim = -1) + top, _ = dots.topk(self.sparse_topk, dim=-1) vk = top[..., -1].unsqueeze(-1).expand_as(dots) mask = dots < vk dots.masked_fill_(mask, mask_value) del mask - attn = self.attn_fn(dots, dim = -1) + attn = self.attn_fn(dots, dim=-1) post_softmax_attn = attn.clone() attn = self.dropout(attn) @@ -678,46 +721,47 @@ class Attention(nn.Module): out = out * gates.sigmoid() intermediates = Intermediates( - pre_softmax_attn = pre_softmax_attn, - post_softmax_attn = post_softmax_attn + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn ) return self.to_out(out), intermediates + class AttentionLayers(nn.Module): def __init__( - self, - dim, - depth, - heads = 8, - causal = False, - cross_attend = False, - only_cross = False, - use_scalenorm = False, - use_rms_scaleshift_norm = False, - use_rmsnorm = False, - use_rezero = False, - alibi_pos_bias = False, - alibi_num_heads = None, - alibi_learned = False, - position_infused_attn = False, - rotary_pos_emb = False, - rotary_emb_dim = None, - custom_layers = None, - sandwich_coef = None, - par_ratio = None, - residual_attn = False, - cross_residual_attn = False, - macaron = False, - pre_norm = True, - gate_residual = False, - scale_residual = False, - shift_tokens = 0, - sandwich_norm = False, - use_qk_norm_attn = False, - qk_norm_attn_seq_len = None, - zero_init_branch_output = False, - **kwargs + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + 
zero_init_branch_output=False, + **kwargs ): super().__init__() ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) @@ -736,13 +780,14 @@ class AttentionLayers(nn.Module): rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None - assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + assert not ( + alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' if alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias - self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal) + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) else: self.rel_pos = None @@ -775,14 +820,15 @@ class AttentionLayers(nn.Module): # qk normalization if use_qk_norm_attn: - attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None + attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( + qk_norm_attn_seq_len) else None attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} # zero init if zero_init_branch_output: - attn_kwargs = {**attn_kwargs, 'zero_init_output': True} - ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + attn_kwargs = {**attn_kwargs, 'zero_init_output': True} + ff_kwargs = {**ff_kwargs, 'zero_init_output': True} # calculate layer block order @@ -792,7 +838,7 @@ class AttentionLayers(nn.Module): par_depth = depth * len(default_block) assert 1 < par_ratio <= par_depth, 'par ratio out of range' default_block = tuple(filter(not_equals('f'), default_block)) - par_attn = par_depth // par_ratio + par_attn = par_depth // par_ratio depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn assert len(default_block) <= par_width, 'default block is too large for par_ratio' @@ -818,9 +864,9 @@ class AttentionLayers(nn.Module): is_last_layer = ind == (len(self.layer_types) - 1) if layer_type == 'a': - layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs) + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) elif layer_type == 'c': - layer = Attention(dim, heads = heads, **attn_kwargs) + layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) @@ -836,7 +882,7 @@ class AttentionLayers(nn.Module): layer = branch_fn(layer) residual_fn = GRUGating if gate_residual else Residual - residual = residual_fn(dim, scale_residual = scale_residual) + residual = residual_fn(dim, scale_residual=scale_residual) layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') @@ -857,19 +903,20 @@ class AttentionLayers(nn.Module): ])) def forward( - self, - x, - context = None, - full_context = None, # for passing a list of hidden states from an encoder - mask = None, - context_mask = None, - attn_mask = None, - mems = None, - return_hiddens = False, - norm_scale_shift_inp = None, + self, + x, + context=None, + full_context=None, # for passing a 
list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, ): - - assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True' + + assert not (self.cross_attend ^ (exists(context) or exists( + full_context))), 'context must be passed in if cross_attend is set to True' assert context is None or full_context is None, 'only one of full_context or context can be provided' hiddens = [] @@ -900,12 +947,14 @@ class AttentionLayers(nn.Module): x = pre_branch_norm(x, **norm_args) if layer_type == 'a': - out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem) + out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, + prev_attn, layer_mem) elif layer_type == 'c': if exists(full_context): - out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn) + out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, + None, prev_attn) else: - out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn) + out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn) elif layer_type == 'f': out = checkpoint(block, x) @@ -933,38 +982,42 @@ class AttentionLayers(nn.Module): if return_hiddens: intermediates = LayerIntermediates( - hiddens = hiddens, - attn_intermediates = intermediates + hiddens=hiddens, + attn_intermediates=intermediates ) return x, intermediates return x + class Encoder(AttentionLayers): def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on encoder' - super().__init__(causal = False, **kwargs) + super().__init__(causal=False, **kwargs) + class Decoder(AttentionLayers): def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on decoder' - super().__init__(causal = True, **kwargs) + super().__init__(causal=True, **kwargs) + class CrossAttender(AttentionLayers): def __init__(self, **kwargs): - super().__init__(cross_attend = True, only_cross = True, **kwargs) + super().__init__(cross_attend=True, only_cross=True, **kwargs) + class ViTransformerWrapper(nn.Module): def __init__( - self, - *, - image_size, - patch_size, - attn_layers, - num_classes = None, - dropout = 0., - emb_dropout = 0. + self, + *, + image_size, + patch_size, + attn_layers, + num_classes=None, + dropout=0., + emb_dropout=0. 
): super().__init__() assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' @@ -982,20 +1035,20 @@ class ViTransformerWrapper(nn.Module): self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) - self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None + self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None def forward( - self, - img, - return_embeddings = False + self, + img, + return_embeddings=False ): p = self.patch_size - x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) + x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) x = self.patch_to_embedding(x) b, n, _ = x.shape - cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) x = torch.cat((cls_tokens, x), dim=1) x = x + self.pos_embedding[:, :(n + 1)] x = self.dropout(x) @@ -1008,20 +1061,21 @@ class ViTransformerWrapper(nn.Module): return self.mlp_head(x[:, 0]) + class TransformerWrapper(nn.Module): def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim = None, - max_mem_len = 0., - shift_mem_down = 0, - emb_dropout = 0., - num_memory_tokens = None, - tie_embedding = False, - use_pos_emb = True + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + shift_mem_down=0, + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True ): super().__init__() assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' @@ -1034,7 +1088,8 @@ class TransformerWrapper(nn.Module): self.shift_mem_down = shift_mem_down self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() @@ -1055,14 +1110,14 @@ class TransformerWrapper(nn.Module): nn.init.kaiming_normal_(self.token_emb.weight) def forward( - self, - x, - return_embeddings = False, - mask = None, - return_hiddens = False, - return_attn = False, - mems = None, - **kwargs + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + **kwargs ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens x = self.token_emb(x) @@ -1072,18 +1127,18 @@ class TransformerWrapper(nn.Module): x = self.project_emb(x) if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', b = b) - x = torch.cat((mem, x), dim = 1) + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) # auto-handle masking after appending memory tokens if exists(mask): - mask = F.pad(mask, (num_mem, 0), value = True) + mask = F.pad(mask, (num_mem, 0), value=True) if self.shift_mem_down and exists(mems): mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] mems = [*mems_r, *mems_l] - x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] @@ -1100,17 +1155,18 @@ class 
TransformerWrapper(nn.Module): return out + class ContinuousTransformerWrapper(nn.Module): def __init__( - self, - *, - max_seq_len, - attn_layers, - dim_in = None, - dim_out = None, - emb_dim = None, - emb_dropout = 0., - use_pos_emb = True + self, + *, + max_seq_len, + attn_layers, + dim_in=None, + dim_out=None, + emb_dim=None, + emb_dropout=0., + use_pos_emb=True ): super().__init__() assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' @@ -1119,7 +1175,8 @@ class ContinuousTransformerWrapper(nn.Module): self.max_seq_len = max_seq_len - self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() @@ -1130,13 +1187,13 @@ class ContinuousTransformerWrapper(nn.Module): self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() def forward( - self, - x, - return_embeddings = False, - mask = None, - return_attn = False, - mems = None, - **kwargs + self, + x, + return_embeddings=False, + mask=None, + return_attn=False, + mems=None, + **kwargs ): b, n, _, device = *x.shape, x.device @@ -1144,7 +1201,7 @@ class ContinuousTransformerWrapper(nn.Module): x = x + self.pos_emb(x) x = self.emb_dropout(x) - x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x = self.norm(x) out = self.project_out(x) if not return_embeddings else x @@ -1155,13 +1212,14 @@ class ContinuousTransformerWrapper(nn.Module): return out + class XTransformer(nn.Module): def __init__( - self, - *, - dim, - tie_token_emb = False, - **kwargs + self, + *, + dim, + tie_token_emb=False, + **kwargs ): super().__init__() enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs) @@ -1179,12 +1237,12 @@ class XTransformer(nn.Module): self.encoder = TransformerWrapper( **enc_transformer_kwargs, - attn_layers = Encoder(dim = dim, **enc_kwargs) + attn_layers=Encoder(dim=dim, **enc_kwargs) ) self.decoder = TransformerWrapper( **dec_transformer_kwargs, - attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs) + attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs) ) if tie_token_emb: @@ -1193,11 +1251,11 @@ class XTransformer(nn.Module): self.decoder = AutoregressiveWrapper(self.decoder) @torch.no_grad() - def generate(self, seq_in, seq_out_start, seq_len, src_mask = None, src_attn_mask = None, **kwargs): - encodings = self.encoder(seq_in, mask = src_mask, attn_mask = src_attn_mask, return_embeddings = True) - return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = src_mask, **kwargs) + def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs): + encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True) + return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs) - def forward(self, src, tgt, src_mask = None, tgt_mask = None, src_attn_mask = None): - enc = self.encoder(src, mask = src_mask, attn_mask = src_attn_mask, return_embeddings = True) - out = self.decoder(tgt, context = enc, mask = tgt_mask, context_mask = src_mask) + def forward(self, src, 
tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
+        enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
+        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
         return out
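
For reference, below is a minimal usage sketch (not part of the patch) exercising the classes this commit touches. Since the commit is a pure reformat per its subject line, a forward pass through the reformatted module is expected to behave exactly as before. The import path is assumed from the file path in the diff header (codes/models/lucidrains/x_transformers.py), and the vocabulary size, sequence length, and model dimensions are illustrative placeholders, not values taken from the repository.

# Hypothetical smoke test for the reformatted module.
import torch

# Import path assumed from codes/models/lucidrains/x_transformers.py; adjust to
# however the project exposes the package (e.g. when running from the codes/ directory).
from models.lucidrains.x_transformers import Decoder, TransformerWrapper

model = TransformerWrapper(
    num_tokens=256,        # illustrative vocabulary size
    max_seq_len=1024,      # illustrative maximum sequence length
    attn_layers=Decoder(
        dim=512,           # model width
        depth=8,           # number of decoder blocks
        heads=8,           # attention heads per block
    ),
)

tokens = torch.randint(0, 256, (2, 128))   # (batch, sequence) of token ids
logits = model(tokens)                     # -> (2, 128, 256) logits over the vocabulary
print(logits.shape)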