forked from mrq/DL-Art-School
Check in backing changes (which may have broken something?)
This commit is contained in:
parent
986fc9628d
commit
b476516340
|
@ -48,6 +48,7 @@ class LayerScale(nn.Module):
|
||||||
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
||||||
self.scale = nn.Parameter(scale)
|
self.scale = nn.Parameter(scale)
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
return self.fn(x, **kwargs) * self.scale
|
return self.fn(x, **kwargs) * self.scale
|
||||||
|
|
||||||
|
@ -78,8 +79,13 @@ class FeedForward(nn.Module):
|
||||||
nn.Linear(dim * mult, dim)
|
nn.Linear(dim * mult, dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, only_last_two_elements=False):
|
||||||
return self.net(x)
|
if only_last_two_elements:
|
||||||
|
h = x[:, -2:]
|
||||||
|
h = self.net(h)
|
||||||
|
return torch.cat([x[:, :-2], h], dim=1)
|
||||||
|
else:
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
|
@ -110,7 +116,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
|
def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
|
@ -124,15 +130,19 @@ class Attention(nn.Module):
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None):
|
def forward(self, x, mask = None, only_last_two_elements=False):
|
||||||
b, n, _, h, device = *x.shape, self.heads, x.device
|
b, n, _, h, device = *x.shape, self.heads, x.device
|
||||||
softmax = torch.softmax if not self.stable else stable_softmax
|
softmax = torch.softmax if not self.stable else stable_softmax
|
||||||
|
|
||||||
|
# TODO: Q and V do not need to be recomputed for existing elements in intermediate_latents is specified. V would need to be cached though.
|
||||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
|
|
||||||
|
if only_last_two_elements:
|
||||||
|
q = q[:, :, -2:]
|
||||||
|
assert not exists(mask) # Don't know how to resolve this (currently)
|
||||||
|
|
||||||
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
mask_value = max_neg_value(dots)
|
mask_value = max_neg_value(dots)
|
||||||
|
|
||||||
|
@ -187,11 +197,40 @@ class Transformer(nn.Module):
|
||||||
LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
|
LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
# TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
|
||||||
execute_type = ReversibleSequence if reversible else SequentialSequence
|
execute_type = ReversibleSequence if reversible else SequentialSequence
|
||||||
route_attn = ((True, False),) * depth
|
route_attn = ((True, False),) * depth
|
||||||
attn_route_map = {'mask': route_attn}
|
attn_route_map = {'mask': route_attn}
|
||||||
|
|
||||||
self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
|
self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, return_intermediates=False):
|
||||||
return self.layers(x)
|
intermediates = []
|
||||||
|
for attn, ff in self.layers.layers:
|
||||||
|
x_ff = attn(x)
|
||||||
|
x = ff(x_ff)
|
||||||
|
if return_intermediates:
|
||||||
|
intermediates.append((x_ff, x))
|
||||||
|
if return_intermediates:
|
||||||
|
return x, intermediates
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_last_two(self, x, prev_intermediates):
|
||||||
|
"""
|
||||||
|
Performs an forward pass only on the last two element in the given sequence (allowing it to attend to all other
|
||||||
|
elements). This is useful for faster autoregressive decoding.
|
||||||
|
|
||||||
|
The last two elements are important because in inference, the last element is the prediction candidate and the
|
||||||
|
second-to-last element is a newly selected element from the autoregressive searching process.
|
||||||
|
"""
|
||||||
|
assert(len(prev_intermediates) == self.depth)
|
||||||
|
new_intermediates = []
|
||||||
|
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
|
||||||
|
x = attn(x, only_last_two_elements=True)
|
||||||
|
# Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
|
||||||
|
x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
|
||||||
|
x = ff(x_ff, only_last_two_elements=True)
|
||||||
|
new_intermediates.append((x_ff, x))
|
||||||
|
return x, new_intermediates
|
||||||
|
|
Loading…
Reference in New Issue
Block a user