Fix multiway checkpointing
commit aa36203042
parent 22438a8525
|
@@ -109,7 +109,11 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x
 
-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+        if multiway_split_position is not None:
+            assert self.args.multiway
+            self.apply(set_split_position(multiway_split_position))
+
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
 
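For context, set_split_position is not defined in this hunk. A minimal sketch of what such a helper plausibly does, assuming the multiway wrapper modules expose a split_position attribute (as in torchscale's multiway components), is:

def set_split_position(position):
    # Returns a hook for nn.Module.apply(): every submodule that carries a
    # split_position attribute (i.e. a multiway wrapper) records where the
    # token sequence should be split between its two expert branches.
    def apply_fn(module):
        if hasattr(module, "split_position"):
            module.split_position = position
    return apply_fn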
@@ -360,7 +364,7 @@ class Encoder(nn.Module):
         l_aux = []
         for layer in self.layers:
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
+                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
             )
             if return_all_hiddens:
                 assert encoder_states is not None
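A plausible reading of the fix: activation checkpointing re-runs each layer's forward during the backward pass, so split-position state set once from outside the layer would not survive recomputation; passing multiway_split_position through forward lets the layer restore it on every call. A minimal usage sketch, assuming torch.utils.checkpoint wraps each layer (the run_layer wrapper is illustrative, not part of this commit):

from torch.utils.checkpoint import checkpoint

def run_layer(layer, x, encoder_padding_mask, rel_pos_bias, multiway_split_position):
    # Arguments are passed positionally in the order of EncoderLayer.forward:
    # (x, encoder_padding_mask, attn_mask, rel_pos, multiway_split_position).
    # Because forward calls set_split_position on every invocation, the split
    # is re-applied when checkpoint() recomputes the layer during backward.
    return checkpoint(
        layer,
        x,
        encoder_padding_mask,
        None,            # attn_mask
        rel_pos_bias,
        multiway_split_position,
    )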