autoregressive_codegen: support key_value caching for faster inference
This commit is contained in:
parent
d05e162f95
commit
e634996a9c
|
@ -86,7 +86,12 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
|
||||
out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
|
||||
if use_cache:
|
||||
hidden_states, present_key_values = out
|
||||
else:
|
||||
hidden_states = out
|
||||
present_key_values = None
|
||||
logits = self.transformer.decoder.to_logits(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
@ -95,7 +100,7 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=logits,
|
||||
past_key_values=None,
|
||||
past_key_values=present_key_values,
|
||||
hidden_states=hidden_states,
|
||||
attentions=None,
|
||||
cross_attentions=None,
|
||||
|
@ -259,7 +264,7 @@ class AutoregressiveCodegen(nn.Module):
|
|||
inference_model.store_context(full_context)
|
||||
|
||||
gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
|
||||
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
|
||||
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=True,
|
||||
**hf_generate_kwargs)
|
||||
return gen.sequences
|
||||
|
||||
|
|
|
@ -11,12 +11,10 @@ from einops import rearrange, repeat, reduce
|
|||
from einops.layers.torch import Rearrange
|
||||
|
||||
from entmax import entmax15
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
|
||||
|
||||
# constants
|
||||
from utils.util import checkpoint
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
|
@ -26,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [
|
|||
|
||||
LayerIntermediates = namedtuple('Intermediates', [
|
||||
'hiddens',
|
||||
'attn_intermediates'
|
||||
'attn_intermediates',
|
||||
'past_key_values',
|
||||
])
|
||||
|
||||
|
||||
|
@ -591,7 +590,8 @@ class Attention(nn.Module):
|
|||
sinusoidal_emb=None,
|
||||
rotary_pos_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None
|
||||
mem=None,
|
||||
layer_past=None,
|
||||
):
|
||||
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
|
||||
context)
|
||||
|
@ -622,6 +622,13 @@ class Attention(nn.Module):
|
|||
k = rearrange(k, 'b n d -> b () n d')
|
||||
v = rearrange(v, 'b n (h d) -> b h n d', h=h)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
k = torch.cat([past_key, k], dim=-2)
|
||||
v = torch.cat([past_value, v], dim=-2)
|
||||
k_cache = k
|
||||
v_cache = v
|
||||
|
||||
if exists(rotary_pos_emb) and not has_context:
|
||||
l = rotary_pos_emb.shape[-1]
|
||||
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
|
||||
|
@ -725,7 +732,7 @@ class Attention(nn.Module):
|
|||
post_softmax_attn=post_softmax_attn
|
||||
)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
return self.to_out(out), intermediates, k_cache, v_cache
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
|
@ -913,6 +920,7 @@ class AttentionLayers(nn.Module):
|
|||
mems=None,
|
||||
return_hiddens=False,
|
||||
norm_scale_shift_inp=None,
|
||||
past_key_values=None,
|
||||
):
|
||||
|
||||
assert not (self.cross_attend ^ (exists(context) or exists(
|
||||
|
@ -931,9 +939,13 @@ class AttentionLayers(nn.Module):
|
|||
|
||||
rotary_pos_emb = None
|
||||
if exists(self.rotary_pos_emb):
|
||||
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
|
||||
seq_len = x.shape[1]
|
||||
if past_key_values is not None:
|
||||
seq_len += past_key_values[0][0].shape[-2]
|
||||
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)))
|
||||
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
||||
|
||||
present_key_values = []
|
||||
cross_attn_count = 0
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
if layer_type == 'a':
|
||||
|
@ -946,18 +958,28 @@ class AttentionLayers(nn.Module):
|
|||
if exists(pre_branch_norm):
|
||||
x = pre_branch_norm(x, **norm_args)
|
||||
|
||||
if layer_type == 'a' or layer_type == 'c':
|
||||
if past_key_values is not None:
|
||||
layer_kv = past_key_values.pop(0)
|
||||
layer_past = tuple(s.to(x.device) for s in layer_kv)
|
||||
else:
|
||||
layer_past = None
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
||||
prev_attn, layer_mem)
|
||||
out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
||||
prev_attn, layer_mem, layer_past)
|
||||
elif layer_type == 'c':
|
||||
if exists(full_context):
|
||||
out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
|
||||
None, prev_attn)
|
||||
out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
|
||||
None, prev_attn, None, layer_past)
|
||||
else:
|
||||
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
|
||||
out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
|
||||
elif layer_type == 'f':
|
||||
out = checkpoint(block, x)
|
||||
|
||||
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
|
||||
present_key_values.append((k.detach(), v.detach()))
|
||||
|
||||
if exists(post_branch_norm):
|
||||
out = post_branch_norm(out, **norm_args)
|
||||
|
||||
|
@ -983,7 +1005,8 @@ class AttentionLayers(nn.Module):
|
|||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
attn_intermediates=intermediates,
|
||||
past_key_values=present_key_values
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
|
@ -1117,6 +1140,7 @@ class TransformerWrapper(nn.Module):
|
|||
return_hiddens=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
use_cache=False,
|
||||
**kwargs
|
||||
):
|
||||
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
||||
|
@ -1149,11 +1173,14 @@ class TransformerWrapper(nn.Module):
|
|||
hiddens = intermediates.hiddens
|
||||
return out, hiddens
|
||||
|
||||
res = [out]
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
res.append(attn_maps)
|
||||
if use_cache:
|
||||
res.append(intermediates.past_key_values)
|
||||
|
||||
return out
|
||||
return res
|
||||
|
||||
|
||||
class ContinuousTransformerWrapper(nn.Module):
|
||||
|
@ -1193,6 +1220,7 @@ class ContinuousTransformerWrapper(nn.Module):
|
|||
mask=None,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
use_cache=False,
|
||||
**kwargs
|
||||
):
|
||||
b, n, _, device = *x.shape, x.device
|
||||
|
@ -1206,11 +1234,14 @@ class ContinuousTransformerWrapper(nn.Module):
|
|||
|
||||
out = self.project_out(x) if not return_embeddings else x
|
||||
|
||||
res = [out]
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
res.append(attn_maps)
|
||||
if use_cache:
|
||||
res.append(intermediates.past_key_values)
|
||||
|
||||
return out
|
||||
return tuple(res)
|
||||
|
||||
|
||||
class XTransformer(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue
Block a user