Merge pull request #12 from microsoft/bsz

Batch size first
Shuming Ma 2023-01-16 20:07:52 +08:00 committed by GitHub
commit 82f140a6c4
5 changed files with 22 additions and 26 deletions
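
This PR switches the model-internal tensor layout from sequence-first (seq_len, batch, dim) to batch-first (batch, seq_len, dim): the explicit transposes around the encoder and decoder stacks are dropped, size lookups swap between dim 0 and dim 1, and batch-indexed operations such as index_select now act on dim 0. A minimal standalone sketch of the two layouts (plain PyTorch with illustrative names, not code from this repository):

import torch

seq_len, batch, dim = 5, 2, 4

# Sequence-first layout used before this PR: (T, B, C)
x_seq_first = torch.randn(seq_len, batch, dim)

# Batch-first layout used after this PR: (B, T, C)
x_batch_first = x_seq_first.transpose(0, 1)

# Reordering the batch (e.g. during beam search) now selects along dim 0
new_order = torch.tensor([1, 0])
reordered = x_batch_first.index_select(0, new_order)
assert reordered.shape == (batch, seq_len, dim)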


@@ -371,7 +371,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         )

     def reorder_encoder_out(self, encoder_out, new_order):
-        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
+        new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
         new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
             0, new_order
         )
@@ -382,7 +382,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         encoder_states = encoder_out["encoder_states"]
         if len(encoder_states) > 0:
             for idx, state in enumerate(encoder_states):
-                encoder_states[idx] = state.index_select(1, new_order)
+                encoder_states[idx] = state.index_select(0, new_order)

         return {
             "encoder_out": new_encoder_out,  # T x B x C


@@ -189,9 +189,7 @@ class DecoderLayer(nn.Module):
             x = self.ffn(x)
             l_aux = None
         else:
-            x = x.transpose(0, 1)
             x, l_aux = self.moe_layer(x)
-            x = x.transpose(0, 1)

         if self.drop_path is not None:
             x = self.drop_path(x)
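
The removed transposes only existed to hand the MoE layer a batch-first tensor while the rest of the decoder ran sequence-first; once x is already (batch, seq_len, dim), the round trip is a no-op. A tiny illustration of why the two flows agree (a generic nn.Linear stands in for moe_layer as an assumption, not the repository's MOE implementation):

import torch
import torch.nn as nn

moe_layer = nn.Linear(16, 16)                  # stand-in for a module expecting (B, T, C)

x_seq_first = torch.randn(5, 2, 16)            # (T, B, C): the old decoder-internal layout
x_batch_first = x_seq_first.transpose(0, 1)    # (B, T, C): the new layout

# Old flow: transpose in, run the layer, transpose back out.
y_old = moe_layer(x_seq_first.transpose(0, 1)).transpose(0, 1)
# New flow: the tensor is already batch-first, so no round trip is needed.
y_new = moe_layer(x_batch_first)

assert torch.allclose(y_old.transpose(0, 1), y_new)
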
@@ -391,26 +389,25 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
-        x = x.transpose(0, 1)

         # relative position
         self_attn_rel_pos_bias = None
         slen = prev_output_tokens.size(1)
         if self.self_attn_relative_position is not None:
             self_attn_rel_pos_bias = self.self_attn_relative_position(
-                batch_size=x.size(1), qlen=slen, klen=slen
+                batch_size=x.size(0), qlen=slen, klen=slen
             )
             if incremental_state is not None:
-                self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
+                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
             cross_attn_rel_pos_bias = self.cross_attn_relative_position(
-                batch_size=x.size(1),
+                batch_size=x.size(0),
                 qlen=slen,
-                klen=encoder_out["encoder_out"].size(0),
+                klen=encoder_out["encoder_out"].size(1),
             )
             if incremental_state is not None:
-                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
+                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

         # decoder layers
         inner_states = [x]
@@ -423,7 +420,7 @@ class Decoder(nn.Module):
         for idx, layer in enumerate(self.layers):
             if incremental_state is None:
                 self_attn_mask = torch.triu(
-                    torch.zeros([x.size(0), x.size(0)])
+                    torch.zeros([x.size(1), x.size(1)])
                     .float()
                     .fill_(float("-inf"))
                     .type_as(x),
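
Because x is now (batch, seq_len, dim), the side of the causal mask comes from x.size(1) instead of x.size(0). A hedged, runnable sketch of the same mask construction (illustrative shapes only):

import torch

x = torch.randn(2, 6, 16)        # (B, T, C), batch-first decoder input
seq_len = x.size(1)              # the sequence length now lives on dim 1

# Upper-triangular mask: -inf above the diagonal blocks attention to future tokens.
self_attn_mask = torch.triu(
    torch.zeros([seq_len, seq_len]).float().fill_(float("-inf")).type_as(x),
    1,
)
assert self_attn_mask.shape == (seq_len, seq_len)
assert self_attn_mask[0, 0] == 0 and self_attn_mask[0, 1] == float("-inf")
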
@@ -452,8 +449,6 @@ class Decoder(nn.Module):
         if self.layer_norm is not None:
             x = self.layer_norm(x)

-        x = x.transpose(0, 1)
-
         if not features_only:
             x = self.output_layer(x)


@@ -348,8 +348,6 @@ class Encoder(nn.Module):
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

-        x = x.transpose(0, 1)
-
         encoder_states = []

         if return_all_hiddens:
@@ -358,7 +356,7 @@ class Encoder(nn.Module):
         rel_pos_bias = None
         if self.relative_position is not None:
             rel_pos_bias = self.relative_position(
-                batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
+                batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )

         l_aux = []
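
The encoder's padding-mask multiply is unchanged because it already assumed a batch-first (batch, seq_len) mask; only the size lookups for the relative-position bias move between dims 0 and 1. A small standalone sketch of that padding multiply and the new size reads (illustrative tensors, not the repository's Encoder):

import torch

x = torch.randn(2, 5, 8)                      # (B, T, C), batch-first
encoder_padding_mask = torch.tensor(          # (B, T), 1 at padded positions
    [[0, 0, 0, 1, 1],
     [0, 0, 0, 0, 0]]
)

# Zero out embeddings at padded positions (this line is unchanged by the PR).
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
assert torch.all(x[0, 3:] == 0)

# Size lookups after the change: batch from dim 0, sequence length from dim 1.
batch_size, qlen = x.size(0), x.size(1)
assert (batch_size, qlen) == (2, 5)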


@@ -24,6 +24,7 @@ class MultiheadAttention(nn.Module):
         subln=False,
     ):
         super().__init__()
+        self.args = args
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
@@ -68,24 +69,26 @@ class MultiheadAttention(nn.Module):
         attn_mask=None,
         rel_pos=None,
     ):
-        tgt_len, bsz, embed_dim = query.size()
+        bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
         assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]

-        src_len, key_bsz, _ = key.size()
+        key_bsz, src_len, _ = key.size()
         assert key_bsz == bsz, f"{query.size(), key.size()}"
         assert value is not None
-        assert src_len, bsz == value.shape[:2]
+        assert bsz, src_len == value.shape[:2]

         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
         q *= self.scaling

-        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

         if incremental_state is not None:
             if "prev_key" in incremental_state:
@@ -138,7 +141,7 @@ class MultiheadAttention(nn.Module):
         attn_probs = self.dropout_module(attn_weights)

         attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
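
A standalone sketch of the batch-first head split and merge that the new q/k/v and attn lines implement (plain PyTorch with illustrative sizes; variable names follow the diff, not the repository's full MultiheadAttention):

import torch

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim
q = torch.randn(bsz, tgt_len, embed_dim)      # batch-first projection output

# Split heads: (B, T, H*D) -> (B, T, H, D) -> (B, H, T, D) -> (B*H, T, D)
q = q.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)
q = q.reshape(bsz * num_heads, tgt_len, head_dim)
assert q.shape == (bsz * num_heads, tgt_len, head_dim)

# Merge heads after attention: (B*H, T, D) -> (T, B*H, D) -> (T, B, H*D) -> (B, T, H*D)
attn = torch.randn(bsz * num_heads, tgt_len, head_dim)
attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
assert attn.shape == (bsz, tgt_len, embed_dim)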


@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn


-def MultiwayWrapper(args, module, dim=0):
+def MultiwayWrapper(args, module, dim=1):
     if args.multiway:
         return MultiwayNetwork(module, dim=dim)
     return module
@@ -22,7 +22,7 @@ def set_split_position(position):


 class MultiwayNetwork(nn.Module):
-    def __init__(self, module, dim=0):
+    def __init__(self, module, dim=1):
         super().__init__()
         self.dim = dim
         self.A = module
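
MultiwayWrapper's default split dimension moves from 0 to 1 because, with batch-first tensors, the token axis that the multiway network splits on is now dim 1. A hedged sketch of that split using plain torch.split (the routing of the two halves to separate expert modules is an inference from the dim change, not code from the repository):

import torch

x = torch.randn(2, 7, 16)   # (B, T, C), batch-first
split_position = 3          # positions before the split go to one expert, the rest to the other

# With batch-first inputs the token axis is dim 1, so the multiway split
# happens along dim=1 rather than dim=0.
x1, x2 = torch.split(x, [split_position, x.size(1) - split_position], dim=1)
assert x1.shape == (2, 3, 16) and x2.shape == (2, 4, 16)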