From e9f3abcae7aee5c496487969a3be24e2cc0a6b81 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Fri, 8 Apr 2022 09:25:21 -0600
Subject: [PATCH] xtransformers: add key/value caching for autoregressive decoding

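Add key/value caching to the attention stack so that autoregressive
decoding can reuse previously computed keys and values instead of
re-attending over the whole prefix at every step:

* Attention.forward() accepts an optional layer_past (key, value) pair,
  prepends it to the freshly projected k/v, and returns the combined
  (pre-rotary) tensors alongside its usual outputs, so rotary positions
  are recomputed over the full sequence each step.
* AttentionLayers.forward() threads an optional past_key_values list
  through the attention and cross-attention layers and collects the
  per-layer caches into LayerIntermediates.past_key_values. Decoding a
  causal model with rotary embeddings now requires expected_seq_len so
  the rotary table covers the final sequence length.
* TransformerWrapper and ContinuousTransformerWrapper gain a use_cache
  flag; their forward now returns a sequence containing the usual
  output, followed by attention maps (return_attn) and/or the cache
  (use_cache).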
---
 models/xtransformers.py | 67 ++++++++++++++++++++++++++++++++---------
 1 file changed, 53 insertions(+), 14 deletions(-)

diff --git a/models/xtransformers.py b/models/xtransformers.py
index f203cb2..632349b 100644
--- a/models/xtransformers.py
+++ b/models/xtransformers.py
@@ -24,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [
 
 LayerIntermediates = namedtuple('Intermediates', [
     'hiddens',
-    'attn_intermediates'
+    'attn_intermediates',
+    'past_key_values',
 ])
 
 
@@ -589,7 +590,8 @@ class Attention(nn.Module):
             sinusoidal_emb=None,
             rotary_pos_emb=None,
             prev_attn=None,
-            mem=None
+            mem=None,
+            layer_past=None,
     ):
         b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
             context)
@@ -620,6 +622,13 @@ class Attention(nn.Module):
             k = rearrange(k, 'b n d -> b () n d')
             v = rearrange(v, 'b n (h d) -> b h n d', h=h)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat([past_key, k], dim=-2)
+            v = torch.cat([past_value, v], dim=-2)
+        k_cache = k
+        v_cache = v
+
         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
             (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
@@ -723,7 +732,7 @@ class Attention(nn.Module):
             post_softmax_attn=post_softmax_attn
         )
 
-        return self.to_out(out), intermediates
+        return self.to_out(out), intermediates, k_cache, v_cache
 
 
 class AttentionLayers(nn.Module):
@@ -770,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal
 
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -911,6 +921,8 @@ class AttentionLayers(nn.Module):
             mems=None,
             return_hiddens=False,
             norm_scale_shift_inp=None,
+            past_key_values=None,
+            expected_seq_len=None,
     ):
 
         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -929,9 +941,17 @@ class AttentionLayers(nn.Module):
 
         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
+            seq_len = x.shape[1]
+            if past_key_values is not None:
+                seq_len += past_key_values[0][0].shape[-2]
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
+        present_key_values = []
         cross_attn_count = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
             if layer_type == 'a':
@@ -944,18 +964,28 @@ class AttentionLayers(nn.Module):
             if exists(pre_branch_norm):
                 x = pre_branch_norm(x, **norm_args)
 
+            if layer_type == 'a' or layer_type == 'c':
+                if past_key_values is not None:
+                    layer_kv = past_key_values.pop(0)
+                    layer_past = tuple(s.to(x.device) for s in layer_kv)
+                else:
+                    layer_past = None
+
             if layer_type == 'a':
-                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                        prev_attn, layer_mem)
+                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                              prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                            None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                                  None, prev_attn, None, layer_past)
                 else:
-                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
 
+            if (layer_type == 'a' or layer_type == 'c') and present_key_values is not None:
+                present_key_values.append((k.detach(), v.detach()))
+
             if exists(post_branch_norm):
                 out = post_branch_norm(out, **norm_args)
 
@@ -981,7 +1011,8 @@ class AttentionLayers(nn.Module):
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens=hiddens,
-                attn_intermediates=intermediates
+                attn_intermediates=intermediates,
+                past_key_values=present_key_values
             )
 
             return x, intermediates
@@ -1115,6 +1146,7 @@ class TransformerWrapper(nn.Module):
             return_hiddens=False,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@@ -1147,11 +1179,14 @@ class TransformerWrapper(nn.Module):
             hiddens = intermediates.hiddens
             return out, hiddens
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        return tuple(res)
 
 
 class ContinuousTransformerWrapper(nn.Module):
@@ -1191,6 +1226,7 @@ class ContinuousTransformerWrapper(nn.Module):
             mask=None,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, _, device = *x.shape, x.device
@@ -1204,11 +1240,14 @@ class ContinuousTransformerWrapper(nn.Module):
 
         out = self.project_out(x) if not return_embeddings else x
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        return tuple(res)
 
 
 class XTransformer(nn.Module):
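
A minimal greedy-decoding sketch of how the new cache path can be
driven from calling code (illustrative only, not part of the diff
above). It assumes `model` is a TransformerWrapper around a causal
Decoder built with rotary_pos_emb=True, so expected_seq_len is
required; `prompt` and `max_new_tokens` are made-up names.

    import torch

    model.eval()
    total_len = prompt.shape[1] + max_new_tokens
    tokens = prompt                  # (b, n) LongTensor of seed tokens
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # after the first step, only the newest token is fed in;
            # everything earlier is covered by the cache
            inp = tokens if past_key_values is None else tokens[:, -1:]
            logits, past_key_values = model(inp,
                                            use_cache=True,
                                            past_key_values=past_key_values,
                                            expected_seq_len=total_len)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)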