import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
from einops import rearrange, repeat

from functools import partial
from contextlib import contextmanager

from local_attention import LocalAttention
from axial_positional_embedding import AxialPositionalEmbedding
from models.lucidrains.performer.reversible import ReversibleSequence, SequentialSequence

from distutils.version import LooseVersion

TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0')

try:
    from apex import amp
    APEX_AVAILABLE = True
except:
    APEX_AVAILABLE = False
# helpers

def exists(val):
    return val is not None

def empty(tensor):
    return tensor.numel() == 0

def default(val, d):
    return val if exists(val) else d

@contextmanager
def null_context():
    yield

def cast_tuple(val):
    return (val,) if not isinstance(val, tuple) else val

def get_module_device(module):
    return next(module.parameters()).device

def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

class Always(nn.Module):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        return self.val
# token shifting helper and classes

def shift(t, amount, mask = None):
    if amount == 0:
        return t

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    # pad `amount` zero rows at the start of the sequence dimension and trim the same
    # number from the end, i.e. shift every token's features forward by `amount` positions
    return F.pad(t, (0, 0, amount, -amount), value = 0.)
class PreShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)
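# Rough usage sketch (illustrative only, not executed here): PreShiftTokens splits the
# feature dimension into len(shifts) chunks and shifts each chunk along the sequence
# axis by its amount before calling the wrapped layer, so a shift of 1 lets every
# position see part of its predecessor's features at negligible cost, e.g.
#
#     shifted = PreShiftTokens((0, 1), some_token_mixing_layer)  # hypothetical wrapped module taking (x, **kwargs)
#     y = shifted(torch.randn(1, 1024, 512), mask = torch.ones(1, 1024).bool())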
# kernel functions

# transcribed from jax to pytorch from
# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py

def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.amax(data_dash, dim=-1, keepdim=True)) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps)

    return data_dash.type_as(data)
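# A note on what softmax_kernel computes (my reading of the FAVOR+ construction in
# "Rethinking Attention with Performers", Choromanski et al. 2020): with x' = x * d**-0.25
# and projection rows w_1..w_m, the positive random features are
#
#     phi(x)_i = m**-0.5 * exp(w_i . x' - ||x'||**2 / 2)
#
# so that E[phi(q) . phi(k)] = exp(q . k / sqrt(d)), the scaled softmax attention kernel.
# Above, `data_dash` holds w_i . x', `diag_data` holds ||x'||**2 / 2, and `ratio` is m**-0.5.
# The subtracted torch.amax terms are for numerical stability only; the constant factor
# they introduce effectively cancels once scores are normalized by D_inv in
# linear_attention below.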
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)
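# generalized_kernel is the "generalized attention" variant from the same paper: the
# softmax feature map is swapped for an arbitrary kernel_fn (ReLU by default) applied to
# the randomly projected queries/keys, and with projection_matrix = None it simply applies
# kernel_fn to the d**-0.25 scaled inputs. kernel_epsilon keeps the features strictly
# positive so the normalization in linear_attention cannot divide by zero.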
def orthogonal_matrix_chunk(cols, device = None):
    unstructured_block = torch.randn((cols, cols), device = device)
    if TORCH_GE_1_8_0:
        q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced')
    else:
        q, r = torch.qr(unstructured_block.cpu(), some = True)
    q, r = map(lambda t: t.to(device), (q, r))
    return q.t()

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
    nb_full_blocks = int(nb_rows / nb_columns)

    block_list = []

    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    return torch.diag(multiplier) @ final_matrix
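# The projection fed to the kernels above is a stack of QR-orthogonalized Gaussian
# blocks whose row norms are either re-drawn to match iid Gaussian rows (scaling = 0)
# or fixed to sqrt(nb_columns) (scaling = 1). A rough sketch of how the pieces compose
# (illustrative only; the attention modules later in this file wire this up internally):
#
#     # q, k, v: (batch, heads, seq_len, dim_head)
#     proj = gaussian_orthogonal_random_matrix(nb_rows = 256, nb_columns = q.shape[-1], device = q.device)
#     q_prime = softmax_kernel(q, projection_matrix = proj, is_query = True)
#     k_prime = softmax_kernel(k, projection_matrix = proj, is_query = False)
#     out = linear_attention(q_prime, k_prime, v)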
# linear attention classes with softmax kernel

# non-causal linear attention
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim = -2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
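# Reading the einsums above: k_cumsum = sum_n phi(k_n), context = sum_n phi(k_n) v_n^T,
# and out_n = (phi(q_n) @ context) * D_inv_n with D_inv_n = 1 / (phi(q_n) . k_cumsum) --
# the associativity trick  attention ~ D^{-1} Q' (K'^T V)  with D = diag(Q' (K'^T 1)),
# which costs O(n * m * d) rather than the O(n^2 * d) of materializing the full attention
# matrix (m = number of random features, d = head dimension). This is the non-causal
# path; the causal case needs a prefix-sum formulation and is handled separately.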