# DL-Art-School/codes/models/lucidrains/performer/performer_pytorch.py
# 2022-01-09 22:32:50 -07:00
#
# 672 lines
# 24 KiB
# Python

import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
from einops import rearrange, repeat
from functools import partial
from contextlib import contextmanager
from local_attention import LocalAttention
from axial_positional_embedding import AxialPositionalEmbedding
from models.lucidrains.performer.reversible import ReversibleSequence, SequentialSequence
from distutils.version import LooseVersion
# True when the installed torch is >= 1.8.0 (selects torch.linalg.qr below).
# NOTE(review): distutils/LooseVersion was removed in Python 3.12; consider
# packaging.version if this file must run there.
TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0')

# apex is optional: record whether it can be imported instead of failing.
try:
    from apex import amp
    APEX_AVAILABLE = True
except Exception:
    # Catch Exception rather than using a bare `except:` so a broken apex
    # install still falls back, while KeyboardInterrupt / SystemExit are no
    # longer silently swallowed.
    APEX_AVAILABLE = False
# helpers
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def empty(tensor):
    """Return True if ``tensor`` contains zero elements."""
    element_count = tensor.numel()
    return element_count == 0
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
@contextmanager
def null_context():
    """No-op context manager; the ``with ... as`` target receives ``None``."""
    yield None
def cast_tuple(val):
    """Wrap ``val`` in a 1-tuple unless it is already a tuple."""
    if isinstance(val, tuple):
        return val
    return (val,)
def get_module_device(module):
    """Return the device of the first parameter yielded by ``module.parameters()``.

    Raises StopIteration if the module has no parameters.
    """
    first_param = next(iter(module.parameters()))
    return first_param.device
def find_modules(nn_module, type):
    """Collect every module from ``nn_module.modules()`` that is an instance of ``type``.

    The parameter name ``type`` shadows the builtin but is kept so keyword
    callers keep working.
    """
    return list(filter(lambda candidate: isinstance(candidate, type), nn_module.modules()))
class Always(nn.Module):
    """Module whose forward ignores every input and returns the value given at construction."""

    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        del args, kwargs  # inputs are intentionally ignored
        return self.val
# token shifting helper and classes
def shift(t, amount, mask = None):
    """Shift ``t`` by ``amount`` positions along dim -2, zero-filling vacated slots.

    A positive ``amount`` moves content toward higher indices (the tail is
    cropped); a negative one moves it toward lower indices. When ``mask`` is
    given, positions where it is False are zeroed before shifting.
    """
    if not amount:
        return t
    if mask is not None:
        t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    # Pad the front of dim -2 by `amount` and the back by `-amount`; the
    # negative side of the pair crops, so the length is unchanged.
    return F.pad(t, (0, 0, amount, -amount), value=0.)
class PreShiftTokens(nn.Module):
    """Shift slices of the feature dim along the sequence before calling ``fn``.

    The last dim of ``x`` is split into pieces of width
    ``x.shape[-1] // len(shifts)``. The i-th of the first ``len(shifts)``
    pieces is moved by ``shifts[i]`` via ``shift``. Any leftover pieces pass
    through untouched. The pieces are then re-concatenated and handed to
    ``fn`` together with the original keyword arguments.
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        n_shifts = len(self.shifts)
        width = x.shape[-1] // n_shifts
        pieces = x.split(width, dim=-1)
        shifted = [
            shift(piece, amount, mask=mask)
            for piece, amount in zip(pieces[:n_shifts], self.shifts)
        ]
        merged = torch.cat((*shifted, *pieces[n_shifts:]), dim=-1)
        return self.fn(merged, **kwargs)
# kernel functions
# transcribed from jax to pytorch from
# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    """Positive random-feature map for the softmax kernel.

    Args:
        data: queries or keys. The last dim is the feature dim ``d`` and the
            second-to-last is the sequence dim (presumably ``(b, h, n, d)``;
            TODO confirm against callers).
        projection_matrix: ``(m, d)`` random projection, where ``m`` is the
            number of random features.
        is_query: if True, stabilise with the per-row max. Otherwise use the
            max over the last two dims jointly.
        normalize_data: scale ``data`` by ``d ** -0.25`` before projecting.
        eps: added after the exponential so features stay strictly positive.
        device: unused; kept for signature compatibility.

    Returns:
        Features of shape ``data.shape[:-1] + (m,)`` in ``data``'s dtype.
    """
    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
    ratio = projection_matrix.shape[0] ** -0.5

    # Cast the (m, d) projection once and let einsum broadcast it over the
    # leading dims. This avoids expanding it to (b, h, m, d) and then casting
    # (which materialises b*h copies when dtypes differ).
    projection = projection_matrix.type_as(data)
    data_dash = torch.einsum('...id,jd->...ij', data_normalizer * data, projection)

    # ||x||^2 / 2 term of the exponent, using the same data normalisation.
    diag_data = (data ** 2).sum(dim=-1, keepdim=True) / 2.0 * (data_normalizer ** 2)

    # The max is subtracted purely for numerical stability: it rescales all
    # features of a row (queries) or of the whole last-two-dims slice (keys)
    # by one common factor. It is detached so no gradient is routed through
    # the argmax element.
    # NOTE(review): for keys the max spans the sequence dim; if used for
    # causal attention, later positions can influence earlier ones via eps.
    if is_query:
        stabilizer = torch.amax(data_dash, dim=-1, keepdim=True).detach()
    else:
        stabilizer = torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()

    data_dash = ratio * (torch.exp(data_dash - diag_data - stabilizer) + eps)
    return data_dash.type_as(data)
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
    """Generalized random-feature map: ``kernel_fn(projected data) + kernel_epsilon``.

    Args:
        data: input whose last dim is the feature dim ``d`` and whose
            second-to-last dim indexes positions (presumably ``(b, h, n, d)``;
            TODO confirm against callers).
        projection_matrix: ``(m, d)`` projection. If ``None``, ``kernel_fn``
            is applied directly to the (optionally normalised) data and the
            result is returned without a dtype cast.
        kernel_fn: elementwise non-linearity. The default is a single shared
            ``nn.ReLU`` instance.
        kernel_epsilon: constant added to every feature.
        normalize_data: scale ``data`` by ``d ** -0.25`` first.
        device: unused; kept for signature compatibility.

    Returns:
        ``data.shape[:-1] + (m,)`` features in ``data``'s dtype when projecting.
    """
    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    # Cast once and broadcast over leading dims instead of expanding the
    # projection to (b, h, m, d) and then casting the expanded copy.
    projection = projection_matrix.type_as(data)
    data_dash = torch.einsum('...id,jd->...ij', data_normalizer * data, projection)
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)
def orthogonal_matrix_chunk(cols, device = None):
    """Return a random ``(cols, cols)`` matrix with orthonormal rows on ``device``.

    A Gaussian matrix is sampled on ``device`` and QR-decomposed on the CPU.
    The transpose of its Q factor is returned, moved back to ``device``.
    """
    unstructured_block = torch.randn((cols, cols), device = device)
    # QR is done on the CPU copy (presumably for backend support; TODO confirm).
    if TORCH_GE_1_8_0:
        q, _ = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced')
    else:
        q, _ = torch.qr(unstructured_block.cpu(), some = True)
    # Only Q is used, so the R factor is not copied back to the device.
    return q.to(device).t()
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
    """Build an ``(nb_rows, nb_columns)`` matrix of stacked orthogonal blocks, row-scaled.

    Rows are generated ``nb_columns`` at a time by ``orthogonal_matrix_chunk``.
    The final partial block is truncated to the remaining row count.

    Args:
        nb_rows: number of rows (random features).
        nb_columns: number of columns (feature dim).
        scaling: ``0`` scales each row by the norm of an independent Gaussian
            vector of length ``nb_columns``; ``1`` scales every row by
            ``sqrt(nb_columns)``.
        device: device for the sampled tensors.

    Raises:
        ValueError: if ``scaling`` is neither 0 nor 1 (checked before any work).
    """
    if scaling not in (0, 1):
        raise ValueError(f'Invalid scaling {scaling}')

    nb_full_blocks, remaining_rows = divmod(nb_rows, nb_columns)

    block_list = [orthogonal_matrix_chunk(nb_columns, device = device) for _ in range(nb_full_blocks)]
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q[:remaining_rows])
    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    else:
        multiplier = math.sqrt(float(nb_columns)) * torch.ones((nb_rows,), device = device)

    # Row scaling via broadcasting: same result as diag(multiplier) @ final_matrix
    # without allocating an (nb_rows, nb_rows) matrix and doing a matmul.
    return multiplier.unsqueeze(-1) * final_matrix
# linear attention classes with softmax kernel
# non-causal linear attention
def linear_attention(q, k, v):
    """Non-causal linear attention over feature-mapped queries and keys.

    Computes ``out[n] = (sum_m (q[n] . k[m]) v[m]) / (q[n] . sum_m k[m])``
    by contracting keys with values first, so the (n x m) score matrix is
    never formed.
    """
    context = torch.einsum('...nd,...ne->...de', k, v)
    key_sum = k.sum(dim=-2).type_as(q)
    normalizer = torch.einsum('...nd,...d->...n', q, key_sum)
    d_inv = 1. / normalizer
    return torch.einsum('...de,...nd,...n->...ne', context, q, d_inv)