fix context
This commit is contained in:
parent
37bdfe82b2
commit
33ef17e9e5
|
@ -874,11 +874,9 @@ class AttentionLayers(nn.Module):
|
|||
x = pre_branch_norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
|
||||
out, inter = checkpoint(block_fn, x)
|
||||
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
|
||||
elif layer_type == 'c':
|
||||
block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
|
||||
out, inter = checkpoint(block_fn, x)
|
||||
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
|
||||
elif layer_type == 'f':
|
||||
out = checkpoint(block, x)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user