diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index c7dfc3c3..453de22a 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -874,11 +874,9 @@ class AttentionLayers(nn.Module):
                 x = pre_branch_norm(x)
 
             if layer_type == 'a':
-                block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
-                out, inter = checkpoint(block_fn, x)
+                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
             elif layer_type == 'c':
-                block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
-                out, inter = checkpoint(block_fn, x)
+                out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
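
Note on the calling convention: the removed call sites bound keyword arguments to the block via functools.partial and checkpointed the resulting closure, while the added call sites flatten every argument into the block's positional order, filling unused slots with None. Below is a minimal sketch of that difference, assuming `checkpoint` behaves like `torch.utils.checkpoint.checkpoint` (the wrapper used in this file is not shown in the hunk); `toy_block` is a hypothetical stand-in, not the attention/feedforward block from x_transformers.py.

    # Hedged sketch: `toy_block` is hypothetical; `checkpoint` is assumed to be
    # (a thin wrapper around) torch.utils.checkpoint.checkpoint.
    import functools

    import torch
    from torch.utils.checkpoint import checkpoint

    def toy_block(x, context=None, mask=None):
        # Mirrors a block whose optional inputs default to None, so unused
        # slots can be filled with None when calling positionally.
        return x + context if context is not None else x * 2

    x = torch.randn(2, 4, requires_grad=True)
    context = torch.randn(2, 4)

    # Removed style: bind keyword arguments up front, checkpoint the partial.
    out_partial = checkpoint(functools.partial(toy_block, context=context, mask=None), x)

    # Added style: pass everything positionally, with None placeholders for
    # arguments that do not apply at this call site, so every tensor reaches
    # checkpoint as a direct argument instead of being hidden inside a closure.
    out_positional = checkpoint(toy_block, x, context, None)

    assert torch.allclose(out_partial, out_positional)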