From c075fe72e219b34b302be77aee6592502312bf5c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 9 Jan 2022 22:10:07 -0700 Subject: [PATCH] import performer repo --- codes/models/lucidrains/performer/__init__.py | 3 + .../performer/autoregressive_wrapper.py | 107 +++ .../lucidrains/performer/performer_enc_dec.py | 79 +++ .../lucidrains/performer/performer_pytorch.py | 671 ++++++++++++++++++ .../models/lucidrains/performer/reversible.py | 156 ++++ 5 files changed, 1016 insertions(+) create mode 100644 codes/models/lucidrains/performer/__init__.py create mode 100644 codes/models/lucidrains/performer/autoregressive_wrapper.py create mode 100644 codes/models/lucidrains/performer/performer_enc_dec.py create mode 100644 codes/models/lucidrains/performer/performer_pytorch.py create mode 100644 codes/models/lucidrains/performer/reversible.py diff --git a/codes/models/lucidrains/performer/__init__.py b/codes/models/lucidrains/performer/__init__.py new file mode 100644 index 00000000..0dfe9687 --- /dev/null +++ b/codes/models/lucidrains/performer/__init__.py @@ -0,0 +1,3 @@ +from performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention, ProjectionUpdater +from autoregressive_wrapper import AutoregressiveWrapper +from performer_enc_dec import PerformerEncDec diff --git a/codes/models/lucidrains/performer/autoregressive_wrapper.py b/codes/models/lucidrains/performer/autoregressive_wrapper.py new file mode 100644 index 00000000..fa28b771 --- /dev/null +++ b/codes/models/lucidrains/performer/autoregressive_wrapper.py @@ -0,0 +1,107 @@ +from functools import partial +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +def exists(val): + return val is not None + +def top_p(logits, thres = 0.9): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cum_probs > (1 - thres) + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + sorted_logits[sorted_indices_to_remove] = float('-inf') + return sorted_logits.scatter(1, sorted_indices, sorted_logits) + +def top_k(logits, thres = 0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float('-inf')) + probs.scatter_(1, ind, val) + return probs + +def repetition_penalty_fn(logits, ctx, theta=1.2): + w = torch.ones(logits.shape[-1], dtype=torch.float, device=logits.device) + for i in torch.unique(ctx): + w[i] = theta + return logits/w + +class AutoregressiveWrapper(nn.Module): + def __init__(self, net, ignore_index = 0, pad_value = 0): + super().__init__() + self.pad_value = pad_value + self.ignore_index = ignore_index + + self.net = net + self.max_seq_len = net.max_seq_len + + @torch.no_grad() + def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs): + was_training = self.net.training + num_dims = len(start_tokens.shape) + + if num_dims == 1: + start_tokens = start_tokens[None, :] + + b, t = start_tokens.shape + + self.net.eval() + out = start_tokens + input_mask = kwargs.pop('mask', None) + + if input_mask is None: + input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device) + + # in case of conditional generation, if enc_mask is not provided use the correct context_mask + context_mask = kwargs.pop('context_mask', None) + + if 'context' in kwargs and not exists(context_mask): + context = kwargs['context'] + context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device) + + kwargs.update(context_mask = context_mask) + + for _ in range(seq_len): + x = out[:, -self.max_seq_len:] + input_mask = input_mask[:, -self.max_seq_len:] + logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :] + if repetition_penalty > 1.0: + logits = repetition_penalty_fn(logits, out[-repetition_penalty_ctx:], theta=repetition_penalty) + filtered_logits = filter_logits_fn(logits, thres = filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + input_mask = F.pad(input_mask, (0, 1), value=True) + + if eos_token is not None and (sample == eos_token).all(): + break + + out = out[:, t:] + + if num_dims == 1: + out = out.squeeze(0) + + self.net.train(was_training) + return out + + def forward(self, x, **kwargs): + xi = x[:, :-1] + xo = x[:, 1:] + + # help auto-solve an area of confusion around input masks in auto-regressive + # if user supplies a mask that is only off by one from the source sequence, resolve it for them + mask = kwargs.pop('mask', None) + if mask is not None and mask.shape[1] == x.shape[1]: + mask = mask[:, :-1] + kwargs.update(mask = mask) + + out = self.net(xi, **kwargs) + + loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index) + return loss diff --git a/codes/models/lucidrains/performer/performer_enc_dec.py b/codes/models/lucidrains/performer/performer_enc_dec.py new file mode 100644 index 00000000..192f63a0 --- /dev/null +++ b/codes/models/lucidrains/performer/performer_enc_dec.py @@ -0,0 +1,79 @@ +import re +import torch +from torch import nn +from performer_pytorch import PerformerLM +from autoregressive_wrapper import AutoregressiveWrapper + +ENC_PREFIX = 'enc_' +DEC_PREFIX = 'dec_' + +def group_dict_by_key(cond, d): + return_val = [dict(),dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + +def string_begins_with(prefix, str): + return bool(re.match(f'^{prefix}', str)) + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(lambda x: string_begins_with(prefix, x), d) + +def group_by_key_prefix_and_remove_prefix(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + +def extract_enc_dec_kwargs(kwargs): + enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs) + dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs) + return enc_kwargs, dec_kwargs, kwargs + +def extract_and_set_enc_dec_kwargs(kwargs): + enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs) + if 'mask' in enc_kwargs: + dec_kwargs.setdefault('context_mask', enc_kwargs['mask']) + return enc_kwargs, dec_kwargs, kwargs + +class PerformerEncDec(nn.Module): + def __init__( + self, + dim, + ignore_index = 0, + pad_value = 0, + tie_token_embeds = False, + no_projection = False, + **kwargs + ): + super().__init__() + enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs) + + assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder' + + enc_kwargs['dim'] = dec_kwargs['dim'] = dim + enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection + + dec_kwargs['causal'] = True + dec_kwargs['cross_attend'] = True + + enc = PerformerLM(**enc_kwargs) + dec = PerformerLM(**dec_kwargs) + + if tie_token_embeds: + enc.token_emb = dec.token_emb + + self.enc = enc + self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value) + + @torch.no_grad() + def generate(self, seq_in, seq_out_start, seq_len, **kwargs): + enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs) + encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs) + return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs}) + + def forward(self, seq_in, seq_out, enc_mask = None, **kwargs): + enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs) + encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs) + return self.dec(seq_out, context = encodings, context_mask = enc_mask, **dec_kwargs) \ No newline at end of file diff --git a/codes/models/lucidrains/performer/performer_pytorch.py b/codes/models/lucidrains/performer/performer_pytorch.py new file mode 100644 index 00000000..836ed03f --- /dev/null +++ b/codes/models/lucidrains/performer/performer_pytorch.py @@ -0,0 +1,671 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from torch.cuda.amp import autocast +from einops import rearrange, repeat + +from functools import partial +from contextlib import contextmanager + +from local_attention import LocalAttention +from axial_positional_embedding import AxialPositionalEmbedding +from reversible import ReversibleSequence, SequentialSequence + +from distutils.version import LooseVersion + +TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0') + +try: + from apex import amp + APEX_AVAILABLE = True +except: + APEX_AVAILABLE = False + +# helpers + +def exists(val): + return val is not None + +def empty(tensor): + return tensor.numel() == 0 + +def default(val, d): + return val if exists(val) else d + +@contextmanager +def null_context(): + yield + +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val + +def get_module_device(module): + return next(module.parameters()).device + +def find_modules(nn_module, type): + return [module for module in nn_module.modules() if isinstance(module, type)] + +class Always(nn.Module): + def __init__(self, val): + super().__init__() + self.val = val + + def forward(self, *args, **kwargs): + return self.val + +# token shifting helper and classes + +def shift(t, amount, mask = None): + if amount == 0: + return t + + if exists(mask): + t = t.masked_fill(~mask[..., None], 0.) + + return F.pad(t, (0, 0, amount, -amount), value = 0.) + +class PreShiftTokens(nn.Module): + def __init__(self, shifts, fn): + super().__init__() + self.fn = fn + self.shifts = tuple(shifts) + + def forward(self, x, **kwargs): + mask = kwargs.get('mask', None) + shifts = self.shifts + segments = len(shifts) + feats_per_shift = x.shape[-1] // segments + splitted = x.split(feats_per_shift, dim = -1) + segments_to_shift, rest = splitted[:segments], splitted[segments:] + segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim = -1) + return self.fn(x, **kwargs) + +# kernel functions + +# transcribed from jax to pytorch from +# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py + +def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): + b, h, *_ = data.shape + + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. + + ratio = (projection_matrix.shape[0] ** -0.5) + + projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = projection.type_as(data) + + data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + + diag_data = data ** 2 + diag_data = torch.sum(diag_data, dim=-1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + diag_data = diag_data.unsqueeze(dim=-1) + + if is_query: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - + torch.amax(data_dash, dim=-1, keepdim=True)) + eps) + else: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps) + + return data_dash.type_as(data) + +def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): + b, h, *_ = data.shape + + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. + + if projection_matrix is None: + return kernel_fn(data_normalizer * data) + kernel_epsilon + + projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = projection.type_as(data) + + data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + + data_prime = kernel_fn(data_dash) + kernel_epsilon + return data_prime.type_as(data) + +def orthogonal_matrix_chunk(cols, device = None): + unstructured_block = torch.randn((cols, cols), device = device) + if TORCH_GE_1_8_0: + q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced') + else: + q, r = torch.qr(unstructured_block.cpu(), some = True) + q, r = map(lambda t: t.to(device), (q, r)) + return q.t() + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): + nb_full_blocks = int(nb_rows / nb_columns) + + block_list = [] + + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, device = device) + block_list.append(q) + + remaining_rows = nb_rows - nb_full_blocks * nb_columns + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, device = device) + block_list.append(q[:remaining_rows]) + + final_matrix = torch.cat(block_list) + + if scaling == 0: + multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) + elif scaling == 1: + multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) + else: + raise ValueError(f'Invalid scaling {scaling}') + + return torch.diag(multiplier) @ final_matrix + +# linear attention classes with softmax kernel + +# non-causal linear attention +def linear_attention(q, k, v): + k_cumsum = k.sum(dim = -2) + D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + context = torch.einsum('...nd,...ne->...de', k, v) + out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) + return out + +# efficient causal linear attention, created by EPFL +# TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back +def causal_linear_attention(q, k, v, eps = 1e-6): + from fast_transformers.causal_product import CausalDotProduct + autocast_enabled = torch.is_autocast_enabled() + is_half = isinstance(q, torch.cuda.HalfTensor) + assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available' + cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False) + + causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply + + k_cumsum = k.cumsum(dim=-2) + eps + D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) + + with cuda_context(): + if autocast_enabled: + q, k, v = map(lambda t: t.float(), (q, k, v)) + + out = causal_dot_product_fn(q, k, v) + + out = torch.einsum('...nd,...n->...nd', out, D_inv) + return out + +# inefficient causal linear attention, without cuda code, for reader's reference +# not being used +def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6): + last_k_cumsum = 0 + last_context_cumsum = 0 + outs = [] + + for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))): + k_cumsum = last_k_cumsum + k.cumsum(dim=-2) + + D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps) + context = torch.einsum('...nd,...ne->...nde', k, v) + context_cumsum = last_context_cumsum + context.cumsum(dim=-3) + out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv) + + last_k_cumsum = k_cumsum[:, :, -1:] + last_context_cumsum = context_cumsum[:, :, -1:] + outs.append(out) + + return torch.cat(outs, dim = -2) + +class FastAttention(nn.Module): + def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): + super().__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + + self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) + projection_matrix = self.create_projection() + self.register_buffer('projection_matrix', projection_matrix) + + self.generalized_attention = generalized_attention + self.kernel_fn = kernel_fn + + # if this is turned on, no projection will be used + # queries and keys will be softmax-ed as in the original efficient attention paper + self.no_projection = no_projection + + self.causal = causal + if causal: + try: + import fast_transformers.causal_product.causal_product_cuda + self.causal_linear_fn = partial(causal_linear_attention) + except ImportError: + print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version') + self.causal_linear_fn = causal_linear_attention_noncuda + + @torch.no_grad() + def redraw_projection_matrix(self, device): + projections = self.create_projection(device = device) + self.projection_matrix.copy_(projections) + del projections + + def forward(self, q, k, v): + device = q.device + + if self.no_projection: + q = q.softmax(dim = -1) + k = torch.exp(k) if self.causal else k.softmax(dim = -2) + + elif self.generalized_attention: + create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) + q, k = map(create_kernel, (q, k)) + + else: + create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) + q = create_kernel(q, is_query = True) + k = create_kernel(k, is_query = False) + + attn_fn = linear_attention if not self.causal else self.causal_linear_fn + out = attn_fn(q, k, v) + return out + +# a module for keeping track of when to update the projections + +class ProjectionUpdater(nn.Module): + def __init__(self, instance, feature_redraw_interval): + super().__init__() + self.instance = instance + self.feature_redraw_interval = feature_redraw_interval + self.register_buffer('calls_since_last_redraw', torch.tensor(0)) + + def fix_projections_(self): + self.feature_redraw_interval = None + + def redraw_projections(self): + model = self.instance + + if not self.training: + return + + if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval: + device = get_module_device(model) + + fast_attentions = find_modules(model, FastAttention) + for fast_attention in fast_attentions: + fast_attention.redraw_projection_matrix(device) + + self.calls_since_last_redraw.zero_() + return + + self.calls_since_last_redraw += 1 + + def forward(self, x): + raise NotImplemented + +# classes + +class ReZero(nn.Module): + def __init__(self, fn): + super().__init__() + self.g = nn.Parameter(torch.tensor(1e-3)) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.g + +class PreScaleNorm(nn.Module): + def __init__(self, dim, fn, eps=1e-5): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.ones(1)) + self.eps = eps + + def forward(self, x, **kwargs): + n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) + x = x / n * self.g + return self.fn(x, **kwargs) + +class PreLayerNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class Chunk(nn.Module): + def __init__(self, chunks, fn, along_dim = -1): + super().__init__() + self.dim = along_dim + self.chunks = chunks + self.fn = fn + + def forward(self, x, **kwargs): + if self.chunks == 1: + return self.fn(x, **kwargs) + chunks = x.chunk(self.chunks, dim = self.dim) + return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) + +class FeedForward(nn.Module): + def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): + super().__init__() + activation = default(activation, nn.GELU) + + self.glu = glu + self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) + self.act = activation() + self.dropout = nn.Dropout(dropout) + self.w2 = nn.Linear(dim * mult, dim) + + def forward(self, x, **kwargs): + if not self.glu: + x = self.w1(x) + x = self.act(x) + else: + x, v = self.w1(x).chunk(2, dim=-1) + x = self.act(x) * v + + x = self.dropout(x) + x = self.w2(x) + return x + +class Attention(nn.Module): + def __init__( + self, + dim, + causal = False, + heads = 8, + dim_head = 64, + local_heads = 0, + local_window_size = 256, + nb_features = None, + feature_redraw_interval = 1000, + generalized_attention = False, + kernel_fn = nn.ReLU(), + dropout = 0., + no_projection = False, + qkv_bias = False, + attn_out_bias = True + ): + super().__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + dim_head = default(dim_head, dim // heads) + inner_dim = dim_head * heads + self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) + + self.heads = heads + self.global_heads = heads - local_heads + self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None + + self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) + self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) + self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) + self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): + b, n, _, h, gh = *x.shape, self.heads, self.global_heads + + cross_attend = exists(context) + + context = default(context, x) + context_mask = default(context_mask, mask) if not cross_attend else context_mask + + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + + attn_outs = [] + + if not empty(q): + if exists(context_mask): + global_mask = context_mask[:, None, :, None] + v.masked_fill_(~global_mask, 0.) + + if exists(pos_emb) and not cross_attend: + q, k = apply_rotary_pos_emb(q, k, pos_emb) + + out = self.fast_attention(q, k, v) + attn_outs.append(out) + + if not empty(lq): + assert not cross_attend, 'local attention is not compatible with cross attention' + out = self.local_attn(lq, lk, lv, input_mask = mask) + attn_outs.append(out) + + out = torch.cat(attn_outs, dim = 1) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return self.dropout(out) + +class SelfAttention(Attention): + def forward(self, *args, context = None, **kwargs): + assert not exists(context), 'self attention should not receive context' + return super().forward(*args, **kwargs) + +class CrossAttention(Attention): + def forward(self, *args, context = None, **kwargs): + assert exists(context), 'cross attention should receive context' + return super().forward(*args, context = context, **kwargs) + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x): + t = torch.arange(x.shape[1], device=x.device) + return self.emb(t) + +# rotary positional embedding helpers + +def rotate_every_two(x): + x = rearrange(x, '... (d j) -> ... d j', j = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d j -> ... (d j)') + +def apply_rotary_pos_emb(q, k, sinu_pos): + sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) + sin, cos = sinu_pos.unbind(dim = -2) + sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos)) + q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) + return q, k + +# sinusoidal positional embeddings + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + position = torch.arange(0, max_seq_len, dtype=torch.float) + sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + self.register_buffer('emb', emb) + + def forward(self, x): + return self.emb[None, :x.shape[1], :].to(x) + +# performer + +class Performer(nn.Module): + def __init__( + self, + dim, + depth, + heads, + dim_head, + local_attn_heads = 0, + local_window_size = 256, + causal = False, + ff_mult = 4, + nb_features = None, + feature_redraw_interval = 1000, + reversible = False, + ff_chunks = 1, + generalized_attention = False, + kernel_fn = nn.ReLU(), + use_scalenorm = False, + use_rezero = False, + ff_glu = False, + ff_dropout = 0., + attn_dropout = 0., + cross_attend = False, + no_projection = False, + auto_check_redraw = True, + qkv_bias = True, + attn_out_bias = True, + shift_tokens = False + ): + super().__init__() + layers = nn.ModuleList([]) + local_attn_heads = cast_tuple(local_attn_heads) + local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads + assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' + assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads' + + if use_scalenorm: + wrapper_fn = partial(PreScaleNorm, dim) + elif use_rezero: + wrapper_fn = ReZero + else: + wrapper_fn = partial(PreLayerNorm, dim) + + for _, local_heads in zip(range(depth), local_attn_heads): + + attn = SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias) + ff = Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1) + + if shift_tokens: + shift = (0, 1) if causal else (-1, 0, 1) + attn, ff = map(lambda t: PreShiftTokens(shift, t), (attn, ff)) + + attn, ff = map(wrapper_fn, (attn, ff)) + layers.append(nn.ModuleList([attn, ff])) + + if not cross_attend: + continue + + layers.append(nn.ModuleList([ + wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)), + wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) + ])) + + execute_type = ReversibleSequence if reversible else SequentialSequence + + route_attn = ((True, False),) * depth * (2 if cross_attend else 1) + route_context = ((False, False), (True, False)) * depth + attn_route_map = {'mask': route_attn, 'pos_emb': route_attn} + context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {} + self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) + + # keeping track of when to redraw projections for all attention layers + self.auto_check_redraw = auto_check_redraw + self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval) + + def fix_projection_matrices_(self): + self.proj_updater.feature_redraw_interval = None + + def forward(self, x, **kwargs): + if self.auto_check_redraw: + self.proj_updater.redraw_projections() + return self.net(x, **kwargs) + +class PerformerLM(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + dim, + depth, + heads, + dim_head = 64, + local_attn_heads = 0, + local_window_size = 256, + causal = False, + ff_mult = 4, + nb_features = None, + feature_redraw_interval = 1000, + reversible = False, + ff_chunks = 1, + ff_glu = False, + emb_dropout = 0., + ff_dropout = 0., + attn_dropout = 0., + generalized_attention = False, + kernel_fn = nn.ReLU(), + use_scalenorm = False, + use_rezero = False, + cross_attend = False, + no_projection = False, + tie_embed = False, + rotary_position_emb = True, + axial_position_emb = False, + axial_position_shape = None, + auto_check_redraw = True, + qkv_bias = False, + attn_out_bias = False, + shift_tokens = False + ): + super().__init__() + local_attn_heads = cast_tuple(local_attn_heads) + + self.max_seq_len = max_seq_len + self.token_emb = nn.Embedding(num_tokens, dim) + + if rotary_position_emb: + self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len) + self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len) + elif axial_position_emb: + axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / 64), 64)) + self.pos_emb = AxialPositionalEmbedding(dim, axial_position_shape) + self.layer_pos_emb = Always(None) + else: + self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) + self.layer_pos_emb = Always(None) + + self.dropout = nn.Dropout(emb_dropout) + + self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) + self.norm = nn.LayerNorm(dim) + self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None + + def check_redraw_projections(self): + self.performer.check_redraw_projections() + + def fix_projection_matrices_(self): + self.performer.fix_projection_matrices_() + + def forward(self, x, return_encodings = False, **kwargs): + b, n, device = *x.shape, x.device + assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}' + + # token and positional embeddings + x = self.token_emb(x) + x += self.pos_emb(x) + + x = self.dropout(x) + + # performer layers + + layer_pos_emb = self.layer_pos_emb(x) + x = self.performer(x, pos_emb = layer_pos_emb, **kwargs) + + # norm and to logits + x = self.norm(x) + + if return_encodings: + return x + + if exists(self.to_out): + return self.to_out(x) + + return x @ self.token_emb.weight.t() diff --git a/codes/models/lucidrains/performer/reversible.py b/codes/models/lucidrains/performer/reversible.py new file mode 100644 index 00000000..0b740105 --- /dev/null +++ b/codes/models/lucidrains/performer/reversible.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +from operator import itemgetter +from torch.autograd.function import Function +from torch.utils.checkpoint import get_device_states, set_device_states + +# for routing arguments into the functions of the reversible layer +def route_args(router, args, depth): + routed_args = [(dict(), dict()) for _ in range(depth)] + matched_keys = [key for key in args.keys() if key in router] + + for key in matched_keys: + val = args[key] + for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + return routed_args + +# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html +class Deterministic(nn.Module): + def __init__(self, net): + super().__init__() + self.net = net + self.cpu_state = None + self.cuda_in_fwd = None + self.gpu_devices = None + self.gpu_states = None + + def record_rng(self, *args): + self.cpu_state = torch.get_rng_state() + if torch.cuda._initialized: + self.cuda_in_fwd = True + self.gpu_devices, self.gpu_states = get_device_states(*args) + + def forward(self, *args, record_rng = False, set_rng = False, **kwargs): + if record_rng: + self.record_rng(*args) + + if not set_rng: + return self.net(*args, **kwargs) + + rng_devices = [] + if self.cuda_in_fwd: + rng_devices = self.gpu_devices + + with torch.random.fork_rng(devices=rng_devices, enabled=True): + torch.set_rng_state(self.cpu_state) + if self.cuda_in_fwd: + set_device_states(self.gpu_devices, self.gpu_states) + return self.net(*args, **kwargs) + +# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py +# once multi-GPU is confirmed working, refactor and send PR back to source +class ReversibleBlock(nn.Module): + def __init__(self, f, g): + super().__init__() + self.f = Deterministic(f) + self.g = Deterministic(g) + + def forward(self, x, f_args = {}, g_args = {}): + x1, x2 = torch.chunk(x, 2, dim=2) + y1, y2 = None, None + + with torch.no_grad(): + y1 = x1 + self.f(x2, record_rng=self.training, **f_args) + y2 = x2 + self.g(y1, record_rng=self.training, **g_args) + + return torch.cat([y1, y2], dim=2) + + def backward_pass(self, y, dy, f_args = {}, g_args = {}): + y1, y2 = torch.chunk(y, 2, dim=2) + del y + + dy1, dy2 = torch.chunk(dy, 2, dim=2) + del dy + + with torch.enable_grad(): + y1.requires_grad = True + gy1 = self.g(y1, set_rng=True, **g_args) + torch.autograd.backward(gy1, dy2) + + with torch.no_grad(): + x2 = y2 - gy1 + del y2, gy1 + + dx1 = dy1 + y1.grad + del dy1 + y1.grad = None + + with torch.enable_grad(): + x2.requires_grad = True + fx2 = self.f(x2, set_rng=True, **f_args) + torch.autograd.backward(fx2, dx1, retain_graph=True) + + with torch.no_grad(): + x1 = y1 - fx2 + del y1, fx2 + + dx2 = dy2 + x2.grad + del dy2 + x2.grad = None + + x = torch.cat([x1, x2.detach()], dim=2) + dx = torch.cat([dx1, dx2], dim=2) + + return x, dx + +class _ReversibleFunction(Function): + @staticmethod + def forward(ctx, x, blocks, args): + ctx.args = args + for block, kwarg in zip(blocks, args): + x = block(x, **kwarg) + ctx.y = x.detach() + ctx.blocks = blocks + return x + + @staticmethod + def backward(ctx, dy): + y = ctx.y + args = ctx.args + for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): + y, dy = block.backward_pass(y, dy, **kwargs) + return dy, None, None + +class SequentialSequence(nn.Module): + def __init__(self, layers, args_route = {}): + super().__init__() + assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' + self.layers = layers + self.args_route = args_route + + def forward(self, x, **kwargs): + args = route_args(self.args_route, kwargs, len(self.layers)) + layers_and_args = list(zip(self.layers, args)) + + for (f, g), (f_args, g_args) in layers_and_args: + x = x + f(x, **f_args) + x = x + g(x, **g_args) + return x + +class ReversibleSequence(nn.Module): + def __init__(self, blocks, args_route = {}): + super().__init__() + self.args_route = args_route + self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) + + def forward(self, x, **kwargs): + x = torch.cat([x, x], dim=-1) + + blocks = self.blocks + args = route_args(self.args_route, kwargs, len(blocks)) + args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) + + out = _ReversibleFunction.apply(x, blocks, args) + return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)