fix context

James Betker 2022-04-06 00:45:42 -06:00
parent 37bdfe82b2
commit 33ef17e9e5


@@ -874,11 +874,9 @@ class AttentionLayers(nn.Module):
             x = pre_branch_norm(x)
         if layer_type == 'a':
-            block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
-            out, inter = checkpoint(block_fn, x)
+            out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
         elif layer_type == 'c':
-            block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
-            out, inter = checkpoint(block_fn, x)
+            out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
         elif layer_type == 'f':
             out = checkpoint(block, x)
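
The bug the commit title points at is visible in the 'c' branch: the old functools.partial bound mask, context_mask, and prev_attn, but never context, so the cross-attention block silently ran with context=None. The fix passes every argument positionally through checkpoint, which also makes each tensor an explicit checkpoint input. Below is a minimal toy sketch of that failure mode, not code from this repo: cross_attn_block is a hypothetical stand-in for the real block, and the repo's checkpoint wrapper is assumed to behave like torch.utils.checkpoint.checkpoint.

```python
import functools

import torch
from torch.utils.checkpoint import checkpoint


def cross_attn_block(x, context=None, mask=None, context_mask=None, prev_attn=None):
    # Toy stand-in for a cross-attention block: only mixes in context when given.
    if context is None:
        return x  # silently degrades to a no-op when context is forgotten
    return x + context.mean(dim=1, keepdim=True)


x = torch.randn(2, 4, 8, requires_grad=True)
context = torch.randn(2, 6, 8)
mask = torch.ones(2, 4, dtype=torch.bool)
context_mask = torch.ones(2, 6, dtype=torch.bool)

# Old pattern: kwargs are baked into a partial, and `context` was simply
# never bound, so the block runs with context=None.
block_fn = functools.partial(cross_attn_block, mask=mask, context_mask=context_mask)
old_out = checkpoint(block_fn, x, use_reentrant=True)

# New pattern (mirroring the diff): all arguments are explicit and positional,
# so a missing one is obvious, and every tensor is visible to checkpoint.
new_out = checkpoint(cross_attn_block, x, context, mask, context_mask, None,
                     use_reentrant=True)

print(torch.equal(old_out, x))  # True: context was ignored
print(torch.equal(new_out, x))  # False: context now contributes
```

The trade-off in the new style is that positional calls with None placeholders depend on the block's exact argument order, but they keep the full argument list in one place at the call site instead of splitting it between a partial and the checkpoint call.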