2021-08-05 02:07:45 +00:00
|
|
|
from inspect import isfunction
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn, einsum
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
|
|
# helpers
|
|
|
|
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
|
2021-10-29 23:29:49 +00:00
|
|
|
from utils.util import checkpoint
|
2021-08-05 02:07:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Note: an identical ``exists`` is defined again later in this module; the later
    definition is the one bound at module level after import.
    """
    return not (val is None)
|
|
|
|
|
|
|
|
|
|
|
|
def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is ``None``.

    ``d`` is returned as-is (never called). Note: ``default`` is redefined later in
    this module with a version that calls ``d`` when it is a function; that later
    definition is the one bound at module level after import.
    """
    if val is None:
        return d
    return val
|
|
|
|
|
|
|
|
|
|
|
|
def cast_tuple(val, depth = 1):
    """Coerce ``val`` into a tuple.

    Tuples are returned unchanged, lists are converted element-for-element, and any
    other value is repeated ``depth`` times, e.g. ``cast_tuple(x, 3) == (x, x, x)``.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
|
|
|
|
|
|
|
|
|
|
|
|
class DivideMax(nn.Module):
    """Divide a tensor by its maximum taken along ``dim``.

    The maximum is kept as a size-1 axis so it broadcasts back over ``x``.
    No guard against a zero maximum is applied.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        peak = torch.amax(x, dim = self.dim, keepdim = True)
        return x / peak
|
|
|
|
|
|
|
|
|
|
|
|
# https://arxiv.org/abs/2103.17239
|
|
|
|
class LayerScale(nn.Module):
    """Multiply the output of ``fn`` by a learnable per-channel scale.

    The scale is a ``(1, 1, dim)`` parameter initialised to a small constant that
    depends on ``depth``: 0.1 for depth <= 18, 1e-5 for depth <= 24, otherwise 1e-6.

    Args:
        dim: size of the channel (last) dimension being scaled.
        depth: position of this layer in the stack; selects the initial scale.
        fn: wrapped callable; extra keyword arguments to ``forward`` are passed on.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            start_value = 0.1
        elif depth <= 24:
            start_value = 1e-5
        else:
            start_value = 1e-6
        self.scale = nn.Parameter(torch.full((1, 1, dim), start_value))
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out * self.scale
|
|
|
|
|
|
|
|
|
|
|
|
class PreNorm(nn.Module):
    """Layer-normalise the last dimension of the input, then apply ``fn``.

    Extra keyword arguments given to ``forward`` are forwarded to ``fn`` unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # `norm` is registered before `fn` so parameter order matches earlier checkpoints.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class GEGLU(nn.Module):
    """Gated GELU unit.

    The last dimension is split into two halves; the first half is multiplied by
    GELU of the second half, so the output's last dimension is half the input's.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise GEGLU feed-forward network.

    Structure: Linear(dim -> 2 * hidden) -> GEGLU -> Dropout -> Linear(hidden -> dim),
    where ``hidden = int(dim * mult)``.

    Args:
        dim: input and output feature width.
        dropout: dropout probability applied after the GEGLU gate.
        mult: hidden-width multiplier (int or float).
    """

    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        # nn.Linear needs integer feature counts. With the default float `mult`,
        # `dim * mult` is a float and construction would fail, so round down once
        # here and derive both layer widths from the same integer.
        hidden_dim = int(dim * mult)
        self.net = nn.Sequential(
            # Twice the hidden width: GEGLU splits this in half (value, gate).
            nn.Linear(dim, hidden_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x, only_last_two_elements=False):
        """Apply the network along the last dimension of ``x``.

        Args:
            x: input whose last dimension is ``dim``; when ``only_last_two_elements``
                is set, dimension 1 is treated as the sequence axis.
            only_last_two_elements: if True, only ``x[:, -2:]`` is run through the
                network. Earlier positions are copied from the input unchanged; they
                are NOT transformed, so callers must not treat them as network output.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if not only_last_two_elements:
            return self.net(x)
        tail = self.net(x[:, -2:])
        return torch.cat([x[:, :-2], tail], dim=1)
|
2021-08-05 02:07:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values such as 0 or '' count as existing)."""
    is_missing = val is None
    return not is_missing
|
|
|
|
|
|
|
|
|
|
|
|
def uniq(arr):
    """Return the distinct elements of ``arr`` in first-seen order, as a dict keys view.

    Elements must be hashable.
    """
    return dict.fromkeys(arr, True).keys()
|
|
|
|
|
|
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (``inspect.isfunction``) it is called with
    no arguments and its result returned, which allows lazily computed defaults.
    Any other ``d`` (including builtins and callable objects) is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
|
|
|
|
|
|
|
|
def max_neg_value(t):
    """Return the most negative finite value of ``t``'s floating-point dtype.

    Used as a masking fill value; ``torch.finfo`` only accepts floating dtypes.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
|
|
|
|
|
|
|
|
|
|
|
|
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    """Softmax along ``dim`` with an explicit max-subtraction done at reduced scale.

    ``t`` is divided by ``alpha``, its maximum along ``dim`` is subtracted, and the
    result is multiplied back by ``alpha`` before the softmax. Mathematically this
    equals ``t.softmax(dim)``; the rescaling keeps intermediate magnitudes small.
    """
    reduced = t / alpha
    centred = reduced - reduced.amax(dim = dim, keepdim = True)
    return torch.softmax(centred * alpha, dim = dim)
|
|
|
|
|
|
|
|
|
|
|
|
# classes
|
|
|
|
class Attention(nn.Module):
    """Multi-head causal self-attention with an optional always-visible prefix.

    Args:
        dim: width of the input and output features.
        seq_len: stored as ``self.seq_len``; not otherwise read by this class.
        non_causal_sequence_partition: when > 0, every query may also attend to the
            first ``non_causal_sequence_partition`` key positions, overriding the
            causal mask for those columns.
        heads: number of attention heads.
        dim_head: per-head feature width.
        dropout: dropout probability applied after the output projection.
        stable: if True, use ``stable_softmax`` instead of ``torch.softmax``.
    """

    def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.stable = stable
        self.non_causal_sequence_partition = non_causal_sequence_partition
        # 1/sqrt(dim_head) temperature, applied to the queries before the dot product.
        self.scale = dim_head ** -0.5

        projected_dim = dim_head * heads
        # Keep qkv before the output projection: construction order determines
        # parameter order and the order in which random initialisation is drawn.
        self.to_qkv = nn.Linear(dim, projected_dim * 3, bias = False)
        out_proj = nn.Linear(projected_dim, dim)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x, mask = None, only_last_two_elements=False):
        """Attend over ``x`` and project back to ``dim``.

        Args:
            x: input of shape (batch, seq, dim); the unpack below requires 3-D.
            mask: optional 2-D (batch, seq) key mask; keys where ``~mask`` is true
                are filled with the most negative value. Presumably a bool tensor
                with True = keep (TODO confirm). Must be None when
                ``only_last_two_elements`` is set.
            only_last_two_elements: produce outputs only for the final (at most two)
                query positions; keys and values still span the whole sequence.

        Returns:
            Tensor of shape (batch, queries, dim), where ``queries`` is seq, or at
            most 2 when ``only_last_two_elements`` is set.
        """
        _batch, _length, _ = x.shape
        device = x.device
        attend = stable_softmax if self.stable else torch.softmax

        # TODO: when only the last two positions are queried, q/k/v for all earlier
        # positions are still recomputed here; caching k and v from a previous pass
        # would avoid that work.
        qkv_parts = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = [rearrange(part, 'b n (h d) -> b h n d', h = self.heads) for part in qkv_parts]
        q = q * self.scale

        if only_last_two_elements:
            q = q[:, :, -2:]
            # Key masks are not supported together with last-two-only queries.
            assert not exists(mask)

        scores = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        fill_value = max_neg_value(scores)

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b () () j')
            scores.masked_fill_(~key_mask, fill_value)

        n_queries, n_keys = scores.shape[-2:]
        # 1 marks a blocked (future) key. The diagonal offset aligns the last query
        # row with the last key column, so the mask stays causal when fewer queries
        # than keys are present.
        blocked = torch.ones(n_queries, n_keys, device = device).triu_(n_keys - n_queries + 1)
        prefix = self.non_causal_sequence_partition
        if prefix > 0:
            # The leading `prefix` keys are visible to every query.
            blocked[:, :prefix] = 0
        scores.masked_fill_(blocked.bool(), fill_value)

        weights = attend(scores, dim=-1)
        mixed = torch.einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)
|
|
|
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of residual (attention, feed-forward) layers.

    Each layer pairs ``LayerScale(PreNorm(Attention))`` with
    ``LayerScale(PreNorm(FeedForward))``. The layer list is wrapped in
    ``ReversibleSequence`` / ``SequentialSequence``, but ``forward`` and
    ``infer_last_two`` iterate the ``(attn, ff)`` pairs in ``self.layers.layers``
    directly and apply the residual connections themselves.
    NOTE(review): since the wrapper is never called, ``reversible`` appears to have
    no effect on ``forward`` -- confirm before relying on it.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        sparse_attn = False,
        stable = False,
        non_causal_sequence_partition=0,
    ):
        super().__init__()
        layers = nn.ModuleList([])
        # `sparse_attn` is expanded per layer but not passed to any layer; only the
        # length of `sparse_layer` matters (zip stops at the shorter sequence).
        sparse_layer = cast_tuple(sparse_attn, depth)

        for ind, _ in zip(range(depth), sparse_layer):
            attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)

            layers.append(nn.ModuleList([
                LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
            ]))

        # TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn}

        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
        self.depth = depth

    def forward(self, x, return_intermediates=False):
        """Run every layer with pre-norm residual connections.

        Args:
            x: input sequence; Attention requires a 3-D (batch, seq, dim) tensor.
            return_intermediates: if True, also return a per-layer list of
                ``(post_attention, post_feed_forward)`` activations, the cache
                format consumed by ``infer_last_two``.

        Returns:
            ``x`` or ``(x, intermediates)``.
        """
        intermediates = []
        for attn, ff in self.layers.layers:
            # `checkpoint` comes from utils.util; presumably gradient checkpointing.
            x_ff = x + checkpoint(attn, x)
            x = x_ff + ff(x_ff)
            if return_intermediates:
                intermediates.append((x_ff, x))
        if return_intermediates:
            return x, intermediates
        return x

    def infer_last_two(self, x, prev_intermediates):
        """Recompute only the last two positions of ``x``, reusing cached activations.

        Intended for fast autoregressive decoding. In inference, the last element of
        ``x`` is the prediction candidate and the second-to-last is the element just
        selected by the search. ``prev_intermediates`` is the cache produced by
        ``forward(..., return_intermediates=True)`` (or by a previous call to this
        method). It is expected to cover a sequence one element shorter than ``x``.
        Its final entry (the old candidate) is discarded, and every earlier position
        is assumed unchanged and taken from the cache.

        Args:
            x: full input sequence (batch, seq, dim).
            prev_intermediates: per-layer ``(post_attention, post_feed_forward)``
                activations; must contain ``self.depth`` entries.

        Returns:
            ``(x_out, new_intermediates)`` in the same format as ``forward`` with
            ``return_intermediates=True``.
        """
        assert(len(prev_intermediates) == self.depth)
        new_intermediates = []
        for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
            # Attention outputs for the last two query positions only; keys/values
            # still come from the full sequence `x`.
            attn_tail = attn(x, only_last_two_elements=True)
            # The cached post-attention activations already include the residual
            # (forward stores x + attn(x)), so they are reused verbatim for earlier
            # positions. Only the new tail gets its residual added here.
            x_ff = torch.cat([int_ff[:, :-1], x[:, -2:] + attn_tail], dim=1)
            # LayerScale, PreNorm and FeedForward all act on each position
            # independently, so only the last two positions need the feed-forward.
            ff_in = x_ff[:, -2:]
            out_tail = ff_in + ff(ff_in)
            x = torch.cat([int_out[:, :-1], out_tail], dim=1)
            new_intermediates.append((x_ff, x))
        return x, new_intermediates
|