Check in backing changes (which may have broken something?)

This commit is contained in:
James Betker 2021-10-29 17:22:33 -06:00
parent 986fc9628d
commit b476516340

View File

@ -48,6 +48,7 @@ class LayerScale(nn.Module):
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
@ -78,7 +79,12 @@ class FeedForward(nn.Module):
nn.Linear(dim * mult, dim)
)
def forward(self, x):
def forward(self, x, only_last_two_elements=False):
if only_last_two_elements:
h = x[:, -2:]
h = self.net(h)
return torch.cat([x[:, :-2], h], dim=1)
else:
return self.net(x)
@ -124,15 +130,19 @@ class Attention(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
def forward(self, x, mask = None, only_last_two_elements=False):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
# TODO: Q and V do not need to be recomputed for existing elements in intermediate_latents is specified. V would need to be cached though.
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q = q * self.scale
if only_last_two_elements:
q = q[:, :, -2:]
assert not exists(mask) # Don't know how to resolve this (currently)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)
@ -187,11 +197,40 @@ class Transformer(nn.Module):
LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
]))
# TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn}
self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
self.depth = depth
def forward(self, x):
return self.layers(x)
def forward(self, x, return_intermediates=False):
intermediates = []
for attn, ff in self.layers.layers:
x_ff = attn(x)
x = ff(x_ff)
if return_intermediates:
intermediates.append((x_ff, x))
if return_intermediates:
return x, intermediates
else:
return x
def infer_last_two(self, x, prev_intermediates):
"""
Performs an forward pass only on the last two element in the given sequence (allowing it to attend to all other
elements). This is useful for faster autoregressive decoding.
The last two elements are important because in inference, the last element is the prediction candidate and the
second-to-last element is a newly selected element from the autoregressive searching process.
"""
assert(len(prev_intermediates) == self.depth)
new_intermediates = []
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
x = attn(x, only_last_two_elements=True)
# Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
x = ff(x_ff, only_last_two_elements=True)
new_intermediates.append((x_ff, x))
return x, new_intermediates