diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index a1b9568..dad97ef 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -109,7 +109,11 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x
 
-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+        if multiway_split_position is not None:
+            assert self.args.multiway
+            self.apply(set_split_position(multiway_split_position))
+
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
 
@@ -360,7 +364,7 @@ class Encoder(nn.Module):
         l_aux = []
         for layer in self.layers:
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
+                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
             )
             if return_all_hiddens:
                 assert encoder_states is not None