Check in backing changes (which may have broken something?)

James Betker 2021-10-29 17:22:33 -06:00
parent 986fc9628d
commit b476516340


@@ -48,6 +48,7 @@ class LayerScale(nn.Module):
         scale = torch.zeros(1, 1, dim).fill_(init_eps)
         self.scale = nn.Parameter(scale)
         self.fn = fn
 
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) * self.scale
@@ -78,8 +79,13 @@ class FeedForward(nn.Module):
             nn.Linear(dim * mult, dim)
         )
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, x, only_last_two_elements=False):
+        if only_last_two_elements:
+            h = x[:, -2:]
+            h = self.net(h)
+            return torch.cat([x[:, :-2], h], dim=1)
+        else:
+            return self.net(x)
 
 def exists(val):
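The new fast path in FeedForward.forward only pushes the trailing two positions through the network and splices them back onto the untouched prefix. A minimal sketch of that behavior, using a stand-in two-layer MLP in place of the module's actual net (dims here are illustrative):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # stand-in for FeedForward.net

x = torch.randn(2, 10, 16)              # (batch, sequence, dim)
h = net(x[:, -2:])                      # transform only the last two positions
out = torch.cat([x[:, :-2], h], dim=1)  # splice them back onto the untouched prefix

assert out.shape == x.shape
assert torch.equal(out[:, :-2], x[:, :-2])  # earlier positions pass through unchanged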
@@ -110,7 +116,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
 class Attention(nn.Module):
     def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5
@@ -124,15 +130,19 @@ class Attention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, only_last_two_elements=False):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
+        # TODO: Q and V do not need to be recomputed for existing elements if intermediate_latents is specified. V would need to be cached, though.
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         q = q * self.scale
+        if only_last_two_elements:
+            q = q[:, :, -2:]
+            assert not exists(mask)  # Don't know how to resolve this (currently).
 
         dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         mask_value = max_neg_value(dots)
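For reference, the shape bookkeeping behind the only_last_two_elements branch: queries are truncated to the final two positions while keys and values still span the whole sequence, so those two positions can attend over everything that came before. A small self-contained sketch (dims chosen arbitrarily; einops is already a dependency of this file):

import torch
from einops import rearrange

b, n, h, d = 2, 10, 8, 64
q = torch.randn(b, h, n, d) * (d ** -0.5)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

q = q[:, :, -2:]                                            # only_last_two_elements
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)    # (b, h, 2, n): 2 queries attend over n keys
attn = dots.softmax(dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')                # (b, 2, h * d): outputs only for the last two positions
assert out.shape == (b, 2, h * d)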
@@ -187,11 +197,40 @@ class Transformer(nn.Module):
                 LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
             ]))
 
+        # TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
 
         self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
+        self.depth = depth
 
-    def forward(self, x):
-        return self.layers(x)
+    def forward(self, x, return_intermediates=False):
+        intermediates = []
+        for attn, ff in self.layers.layers:
+            x_ff = attn(x)
+            x = ff(x_ff)
+            if return_intermediates:
+                intermediates.append((x_ff, x))
+        if return_intermediates:
+            return x, intermediates
+        else:
+            return x
+
+    def infer_last_two(self, x, prev_intermediates):
+        """
+        Performs a forward pass only on the last two elements of the given sequence (while still allowing them to
+        attend to all other elements). This is useful for faster autoregressive decoding.
+
+        The last two elements matter because, at inference time, the last element is the prediction candidate and
+        the second-to-last element is the element most recently selected by the autoregressive search process.
+        """
+        assert(len(prev_intermediates) == self.depth)
+        new_intermediates = []
+        for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
+            x = attn(x, only_last_two_elements=True)
+            # Note that x now holds only the last two elements of the sequence. Conjoin it with the cached int_ff
+            # latent to compute the norm.
+            x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
+            x = ff(x_ff, only_last_two_elements=True)
+            new_intermediates.append((x_ff, x))
+        return x, new_intermediates
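A minimal sketch of the index arithmetic inside infer_last_two, assuming the cached int_ff latent covers the previous sequence of length n - 1 whose final position was the old prediction candidate: that stale position is dropped and replaced by the two freshly computed positions, yielding the new length-n latent. The shapes below are illustrative only:

import torch

b, n, dim = 2, 10, 16
int_ff = torch.randn(b, n - 1, dim)   # cached pre-feedforward latent from the previous decoding step
x = torch.randn(b, 2, dim)            # attention output for only the new last two positions

x_ff = torch.cat([int_ff[:, :-1], x], dim=1)  # drop the stale candidate latent, append the fresh pair
assert x_ff.shape == (b, n, dim)

In use, a caller would presumably seed the cache once with forward(x, return_intermediates=True) and then feed the returned intermediates back through infer_last_two on each subsequent decoding step.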