updates for new autoregressive

This commit is contained in:
James Betker 2022-04-08 09:25:21 -06:00
parent 12bac95a17
commit e9f3abcae7

View File

@ -24,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [
LayerIntermediates = namedtuple('Intermediates', [ LayerIntermediates = namedtuple('Intermediates', [
'hiddens', 'hiddens',
'attn_intermediates' 'attn_intermediates',
'past_key_values',
]) ])
@ -589,7 +590,8 @@ class Attention(nn.Module):
sinusoidal_emb=None, sinusoidal_emb=None,
rotary_pos_emb=None, rotary_pos_emb=None,
prev_attn=None, prev_attn=None,
mem=None mem=None,
layer_past=None,
): ):
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
context) context)
@ -620,6 +622,13 @@ class Attention(nn.Module):
k = rearrange(k, 'b n d -> b () n d') k = rearrange(k, 'b n d -> b () n d')
v = rearrange(v, 'b n (h d) -> b h n d', h=h) v = rearrange(v, 'b n (h d) -> b h n d', h=h)
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat([past_key, k], dim=-2)
v = torch.cat([past_value, v], dim=-2)
k_cache = k
v_cache = v
if exists(rotary_pos_emb) and not has_context: if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1] l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
@ -723,7 +732,7 @@ class Attention(nn.Module):
post_softmax_attn=post_softmax_attn post_softmax_attn=post_softmax_attn
) )
return self.to_out(out), intermediates return self.to_out(out), intermediates, k_cache, v_cache
class AttentionLayers(nn.Module): class AttentionLayers(nn.Module):
@ -770,6 +779,7 @@ class AttentionLayers(nn.Module):
self.dim = dim self.dim = dim
self.depth = depth self.depth = depth
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
self.causal = causal
rel_pos_bias = 'rel_pos_bias' in attn_kwargs rel_pos_bias = 'rel_pos_bias' in attn_kwargs
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@ -911,6 +921,8 @@ class AttentionLayers(nn.Module):
mems=None, mems=None,
return_hiddens=False, return_hiddens=False,
norm_scale_shift_inp=None, norm_scale_shift_inp=None,
past_key_values=None,
expected_seq_len=None,
): ):
assert not (self.cross_attend ^ (exists(context) or exists( assert not (self.cross_attend ^ (exists(context) or exists(
@ -929,9 +941,17 @@ class AttentionLayers(nn.Module):
rotary_pos_emb = None rotary_pos_emb = None
if exists(self.rotary_pos_emb): if exists(self.rotary_pos_emb):
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) if not self.training and self.causal:
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
elif expected_seq_len is None:
expected_seq_len = 0
seq_len = x.shape[1]
if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
present_key_values = []
cross_attn_count = 0 cross_attn_count = 0
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
if layer_type == 'a': if layer_type == 'a':
@ -944,18 +964,28 @@ class AttentionLayers(nn.Module):
if exists(pre_branch_norm): if exists(pre_branch_norm):
x = pre_branch_norm(x, **norm_args) x = pre_branch_norm(x, **norm_args)
if layer_type == 'a' or layer_type == 'c':
if past_key_values is not None:
layer_kv = past_key_values.pop(0)
layer_past = tuple(s.to(x.device) for s in layer_kv)
else:
layer_past = None
if layer_type == 'a': if layer_type == 'a':
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
prev_attn, layer_mem) prev_attn, layer_mem, layer_past)
elif layer_type == 'c': elif layer_type == 'c':
if exists(full_context): if exists(full_context):
out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
None, prev_attn) None, prev_attn, None, layer_past)
else: else:
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn) out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
elif layer_type == 'f': elif layer_type == 'f':
out = checkpoint(block, x) out = checkpoint(block, x)
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
present_key_values.append((k.detach(), v.detach()))
if exists(post_branch_norm): if exists(post_branch_norm):
out = post_branch_norm(out, **norm_args) out = post_branch_norm(out, **norm_args)
@ -981,7 +1011,8 @@ class AttentionLayers(nn.Module):
if return_hiddens: if return_hiddens:
intermediates = LayerIntermediates( intermediates = LayerIntermediates(
hiddens=hiddens, hiddens=hiddens,
attn_intermediates=intermediates attn_intermediates=intermediates,
past_key_values=present_key_values
) )
return x, intermediates return x, intermediates
@ -1115,6 +1146,7 @@ class TransformerWrapper(nn.Module):
return_hiddens=False, return_hiddens=False,
return_attn=False, return_attn=False,
mems=None, mems=None,
use_cache=False,
**kwargs **kwargs
): ):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@ -1147,11 +1179,14 @@ class TransformerWrapper(nn.Module):
hiddens = intermediates.hiddens hiddens = intermediates.hiddens
return out, hiddens return out, hiddens
res = [out]
if return_attn: if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
return out, attn_maps res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
return out return res
class ContinuousTransformerWrapper(nn.Module): class ContinuousTransformerWrapper(nn.Module):
@ -1191,6 +1226,7 @@ class ContinuousTransformerWrapper(nn.Module):
mask=None, mask=None,
return_attn=False, return_attn=False,
mems=None, mems=None,
use_cache=False,
**kwargs **kwargs
): ):
b, n, _, device = *x.shape, x.device b, n, _, device = *x.shape, x.device
@ -1204,11 +1240,14 @@ class ContinuousTransformerWrapper(nn.Module):
out = self.project_out(x) if not return_embeddings else x out = self.project_out(x) if not return_embeddings else x
res = [out]
if return_attn: if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
return out, attn_maps res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
return out return tuple(res)
class XTransformer(nn.Module): class XTransformer(nn.Module):