Fix multiway checkpointing

This commit is contained in:
shumingma 2022-12-27 22:32:02 -08:00
parent 22438a8525
commit aa36203042

View File

@@ -109,7 +109,11 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x

-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+        if multiway_split_position is not None:
+            assert self.args.multiway
+            self.apply(set_split_position(multiway_split_position))
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
@@ -360,7 +364,7 @@ class Encoder(nn.Module):
         l_aux = []
         for layer in self.layers:
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
+                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
             )
             if return_all_hiddens:
                 assert encoder_states is not None