Add key/value caching (use_cache / past_key_values) to the autoregressive
codegen inference path and to the x_transformers attention stack.

Review notes for this revision of the patch:
- TransformerWrapper.forward and ContinuousTransformerWrapper.forward return
  the bare output when neither return_attn nor use_cache is set, and a tuple
  otherwise, so existing callers (and InferenceModel.forward with
  use_cache=False) keep receiving a tensor.
- AttentionLayers.forward reads past_key_values by index instead of pop(0),
  so the caller's list is left untouched.
- The condition guarding present_key_values.append no longer depends on
  and/or operator precedence.
- The "index" lines are carried over from the original patch and do not
  reflect the hunks changed here.

diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py
index acedc4a4..f69f99d1 100644
--- a/codes/models/audio/tts/autoregressive_codegen.py
+++ b/codes/models/audio/tts/autoregressive_codegen.py
@@ -86,7 +86,12 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
+        if use_cache:
+            hidden_states, present_key_values = out
+        else:
+            hidden_states = out
+            present_key_values = None
         logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
@@ -95,7 +100,7 @@ class InferenceModel(GPT2PreTrainedModel):
         return CausalLMOutputWithCrossAttentions(
             loss=None,
             logits=logits,
-            past_key_values=None,
+            past_key_values=present_key_values,
             hidden_states=hidden_states,
             attentions=None,
             cross_attentions=None,
@@ -259,7 +264,7 @@ class AutoregressiveCodegen(nn.Module):
         inference_model.store_context(full_context)
 
         gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
+                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=True,
                                        **hf_generate_kwargs)
         return gen.sequences
 
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 038be766..6027b38f 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -11,12 +11,10 @@ from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 from entmax import entmax15
+from torch.utils.checkpoint import checkpoint
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
-# constants
-from utils.util import checkpoint
-
 DEFAULT_DIM_HEAD = 64
 
 Intermediates = namedtuple('Intermediates', [
@@ -26,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [
 
 LayerIntermediates = namedtuple('Intermediates', [
     'hiddens',
-    'attn_intermediates'
+    'attn_intermediates',
+    'past_key_values',
 ])
 
 
@@ -591,7 +590,8 @@ class Attention(nn.Module):
             sinusoidal_emb=None,
             rotary_pos_emb=None,
             prev_attn=None,
-            mem=None
+            mem=None,
+            layer_past=None,
     ):
         b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
             context)
@@ -622,6 +622,13 @@ class Attention(nn.Module):
             k = rearrange(k, 'b n d -> b () n d')
             v = rearrange(v, 'b n (h d) -> b h n d', h=h)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat([past_key, k], dim=-2)
+            v = torch.cat([past_value, v], dim=-2)
+        k_cache = k
+        v_cache = v
+
         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
             (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
@@ -725,7 +732,7 @@ class Attention(nn.Module):
             post_softmax_attn=post_softmax_attn
         )
 
-        return self.to_out(out), intermediates
+        return self.to_out(out), intermediates, k_cache, v_cache
 
 
 class AttentionLayers(nn.Module):
@@ -913,6 +920,7 @@ class AttentionLayers(nn.Module):
             mems=None,
             return_hiddens=False,
             norm_scale_shift_inp=None,
+            past_key_values=None,
     ):
 
         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -931,9 +939,13 @@ class AttentionLayers(nn.Module):
 
         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
+            seq_len = x.shape[1]
+            if past_key_values is not None:
+                seq_len += past_key_values[0][0].shape[-2]
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)))
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
+        present_key_values = []
         cross_attn_count = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
             if layer_type == 'a':
@@ -946,18 +958,28 @@ class AttentionLayers(nn.Module):
             if exists(pre_branch_norm):
                 x = pre_branch_norm(x, **norm_args)
 
+            layer_past = None
+            if layer_type in ('a', 'c') and past_key_values is not None:
+                # Index by the number of caches emitted so far (one per attention
+                # layer) rather than pop(0), which would empty the caller's list.
+                layer_kv = past_key_values[len(present_key_values)]
+                layer_past = tuple(s.to(x.device) for s in layer_kv)
+
             if layer_type == 'a':
-                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                        prev_attn, layer_mem)
+                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                        prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                            None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                            None, prev_attn, None, layer_past)
                 else:
-                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
 
+            if layer_type in ('a', 'c'):
+                present_key_values.append((k.detach(), v.detach()))
+
             if exists(post_branch_norm):
                 out = post_branch_norm(out, **norm_args)
 
@@ -983,7 +1005,8 @@ class AttentionLayers(nn.Module):
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens=hiddens,
-                attn_intermediates=intermediates
+                attn_intermediates=intermediates,
+                past_key_values=present_key_values
             )
 
             return x, intermediates
@@ -1117,6 +1140,7 @@ class TransformerWrapper(nn.Module):
             return_hiddens=False,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@@ -1149,11 +1173,15 @@ class TransformerWrapper(nn.Module):
             hiddens = intermediates.hiddens
             return out, hiddens
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        # A bare tensor is still returned when no extras were requested.
+        return tuple(res) if len(res) > 1 else res[0]
 
 
 class ContinuousTransformerWrapper(nn.Module):
@@ -1193,6 +1221,7 @@ class ContinuousTransformerWrapper(nn.Module):
             mask=None,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, _, device = *x.shape, x.device
@@ -1206,11 +1235,15 @@ class ContinuousTransformerWrapper(nn.Module):
 
         out = self.project_out(x) if not return_embeddings else x
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        # A bare tensor is still returned when no extras were requested.
+        return tuple(res) if len(res) > 1 else res[0]
 
 
 class XTransformer(nn.Module):