From 33ef17e9e5d7ca4e2cc70e2c099f7ee2f4f88239 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 6 Apr 2022 00:45:42 -0600
Subject: [PATCH] fix context

---
 codes/models/lucidrains/x_transformers.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index c7dfc3c3..453de22a 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -874,11 +874,9 @@ class AttentionLayers(nn.Module):
                 x = pre_branch_norm(x)

             if layer_type == 'a':
-                block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
-                out, inter = checkpoint(block_fn, x)
+                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
             elif layer_type == 'c':
-                block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
-                out, inter = checkpoint(block_fn, x)
+                out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
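
Note on the change (not part of the patch): in the removed code, the cross-attention ('c') branch never bound `context` into its `functools.partial`, so the cross-attention block was not receiving its context tensor; the patch instead passes `context`, the masks, and the positional-embedding arguments positionally through `checkpoint`, which also keeps the tensor inputs visible to the checkpointing machinery. The sketch below is a minimal, hypothetical illustration of that positional-argument pattern using `torch.utils.checkpoint.checkpoint` directly; the repo's `checkpoint` may be a project-specific wrapper, and `SimpleCrossAttnBlock`, `dim`, and the tensor shapes are invented for the example.

# Minimal sketch, assuming torch.utils.checkpoint.checkpoint; the block,
# dimensions, and shapes below are hypothetical, not taken from the repo.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleCrossAttnBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(self, x, context=None):
        # Cross-attend to `context` when it is provided, otherwise self-attend.
        kv = context if context is not None else x
        out, _ = self.attn(x, kv, kv)
        return out

block = SimpleCrossAttnBlock(dim=64)
x = torch.randn(2, 16, 64, requires_grad=True)
context = torch.randn(2, 8, 64)

# Tensor arguments are passed positionally so checkpoint sees them as inputs,
# mirroring the patched call sites above. (Recent PyTorch versions may also
# want use_reentrant=False to be passed explicitly.)
out = checkpoint(block, x, context)
out.sum().backward()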