diff --git a/examples/fairseq/models/machine_translation.py b/examples/fairseq/models/machine_translation.py
index 7c94ce8..9063da3 100644
--- a/examples/fairseq/models/machine_translation.py
+++ b/examples/fairseq/models/machine_translation.py
@@ -371,7 +371,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         )

     def reorder_encoder_out(self, encoder_out, new_order):
-        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
+        new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
         new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
             0, new_order
         )
@@ -382,7 +382,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         encoder_states = encoder_out["encoder_states"]
         if len(encoder_states) > 0:
             for idx, state in enumerate(encoder_states):
-                encoder_states[idx] = state.index_select(1, new_order)
+                encoder_states[idx] = state.index_select(0, new_order)

         return {
             "encoder_out": new_encoder_out,  # T x B x C
diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index f996822..942a0cb 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -189,9 +189,7 @@ class DecoderLayer(nn.Module):
             x = self.ffn(x)
             l_aux = None
         else:
-            x = x.transpose(0, 1)
             x, l_aux = self.moe_layer(x)
-            x = x.transpose(0, 1)

         if self.drop_path is not None:
             x = self.drop_path(x)
@@ -391,26 +389,25 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
-        x = x.transpose(0, 1)

         # relative position
         self_attn_rel_pos_bias = None
         slen = prev_output_tokens.size(1)
         if self.self_attn_relative_position is not None:
             self_attn_rel_pos_bias = self.self_attn_relative_position(
-                batch_size=x.size(1), qlen=slen, klen=slen
+                batch_size=x.size(0), qlen=slen, klen=slen
             )
             if incremental_state is not None:
-                self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
+                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
             cross_attn_rel_pos_bias = self.cross_attn_relative_position(
-                batch_size=x.size(1),
+                batch_size=x.size(0),
                 qlen=slen,
-                klen=encoder_out["encoder_out"].size(0),
+                klen=encoder_out["encoder_out"].size(1),
             )
             if incremental_state is not None:
-                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
+                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

         # decoder layers
         inner_states = [x]
@@ -423,7 +420,7 @@ class Decoder(nn.Module):
         for idx, layer in enumerate(self.layers):
             if incremental_state is None:
                 self_attn_mask = torch.triu(
-                    torch.zeros([x.size(0), x.size(0)])
+                    torch.zeros([x.size(1), x.size(1)])
                     .float()
                     .fill_(float("-inf"))
                     .type_as(x),
@@ -452,8 +449,6 @@ class Decoder(nn.Module):
         if self.layer_norm is not None:
             x = self.layer_norm(x)

-        x = x.transpose(0, 1)
-
         if not features_only:
             x = self.output_layer(x)

diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index dad97ef..a9c6b78 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -348,8 +348,6 @@ class Encoder(nn.Module):
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

-        x = x.transpose(0, 1)
-
         encoder_states = []

         if return_all_hiddens:
@@ -358,7 +356,7 @@ class Encoder(nn.Module):
         rel_pos_bias = None
         if self.relative_position is not None:
             rel_pos_bias = self.relative_position(
-                batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
+                batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )

         l_aux = []
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 3c67c28..b5f9c70 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -24,6 +24,7 @@ class MultiheadAttention(nn.Module):
         subln=False,
     ):
         super().__init__()
+        self.args = args
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
@@ -68,24 +69,26 @@ class MultiheadAttention(nn.Module):
         attn_mask=None,
         rel_pos=None,
     ):
-        tgt_len, bsz, embed_dim = query.size()
+        bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
         assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]

-        src_len, key_bsz, _ = key.size()
+        key_bsz, src_len, _ = key.size()
         assert key_bsz == bsz, f"{query.size(), key.size()}"
         assert value is not None
-        assert src_len, bsz == value.shape[:2]
+        assert bsz, src_len == value.shape[:2]

         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
         q *= self.scaling

-        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

         if incremental_state is not None:
             if "prev_key" in incremental_state:
@@ -138,7 +141,7 @@ class MultiheadAttention(nn.Module):
         attn_probs = self.dropout_module(attn_weights)

         attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
diff --git a/torchscale/component/multiway_network.py b/torchscale/component/multiway_network.py
index ea31320..d6a1ac0 100644
--- a/torchscale/component/multiway_network.py
+++ b/torchscale/component/multiway_network.py
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn


-def MultiwayWrapper(args, module, dim=0):
+def MultiwayWrapper(args, module, dim=1):
     if args.multiway:
         return MultiwayNetwork(module, dim=dim)
     return module
@@ -22,7 +22,7 @@ def set_split_position(position):


 class MultiwayNetwork(nn.Module):
-    def __init__(self, module, dim=0):
+    def __init__(self, module, dim=1):
         super().__init__()
         self.dim = dim
         self.A = module
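
A minimal standalone sketch (not part of the patch) of why the new batch-first head split in `MultiheadAttention.forward` is safe: given the same data, `view(bsz, len, heads, head_dim).transpose(1, 2).reshape(...)` yields exactly the same `(bsz * num_heads, seq_len, head_dim)` tensor as the old sequence-first `view(len, bsz * heads, head_dim).transpose(0, 1)` path. The `split_heads` helper and the toy sizes are invented for illustration; only `torch` is assumed.

```python
import torch


def split_heads(x, num_heads):
    """(bsz, seq_len, embed_dim) -> (bsz * num_heads, seq_len, head_dim), as in the patched forward."""
    bsz, seq_len, embed_dim = x.size()
    head_dim = embed_dim // num_heads
    x = x.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)  # (B, H, T, D)
    return x.reshape(bsz * num_heads, seq_len, head_dim)


bsz, tgt_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads
q = torch.randn(bsz, tgt_len, embed_dim)  # batch-first input, (B, T, C)

# Old sequence-first path, fed the same data laid out as (T, B, C).
q_old = (
    q.transpose(0, 1).contiguous()                 # (T, B, C), the pre-patch layout
    .view(tgt_len, bsz * num_heads, head_dim)
    .transpose(0, 1)                               # (B*H, T, D)
)

# New batch-first path from the patch.
q_new = split_heads(q, num_heads)                  # (B*H, T, D)

assert q_new.shape == (bsz * num_heads, tgt_len, head_dim)
assert torch.equal(q_old, q_new)                   # same values, same per-head ordering
```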