Batch size first

This commit is contained in:
shumingma 2023-01-05 01:19:51 -08:00
parent 776b070d68
commit 1a5d2c26fe
5 changed files with 22 additions and 26 deletions

View File

@ -371,7 +371,7 @@ class MTEncoder(Encoder, FairseqEncoder):
)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
0, new_order
)
@ -382,7 +382,7 @@ class MTEncoder(Encoder, FairseqEncoder):
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
encoder_states[idx] = state.index_select(0, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C

View File

@ -189,9 +189,7 @@ class DecoderLayer(nn.Module):
x = self.ffn(x)
l_aux = None
else:
x = x.transpose(0, 1)
x, l_aux = self.moe_layer(x)
x = x.transpose(0, 1)
if self.drop_path is not None:
x = self.drop_path(x)
@ -391,26 +389,25 @@ class Decoder(nn.Module):
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
x = x.transpose(0, 1)
# relative position
self_attn_rel_pos_bias = None
slen = prev_output_tokens.size(1)
if self.self_attn_relative_position is not None:
self_attn_rel_pos_bias = self.self_attn_relative_position(
batch_size=x.size(1), qlen=slen, klen=slen
batch_size=x.size(0), qlen=slen, klen=slen
)
if incremental_state is not None:
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
cross_attn_rel_pos_bias = None
if self.cross_attn_relative_position is not None:
cross_attn_rel_pos_bias = self.cross_attn_relative_position(
batch_size=x.size(1),
batch_size=x.size(0),
qlen=slen,
klen=encoder_out["encoder_out"].size(0),
klen=encoder_out["encoder_out"].size(1),
)
if incremental_state is not None:
cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
# decoder layers
inner_states = [x]
@ -423,7 +420,7 @@ class Decoder(nn.Module):
for idx, layer in enumerate(self.layers):
if incremental_state is None:
self_attn_mask = torch.triu(
torch.zeros([x.size(0), x.size(0)])
torch.zeros([x.size(1), x.size(1)])
.float()
.fill_(float("-inf"))
.type_as(x),
@ -452,8 +449,6 @@ class Decoder(nn.Module):
if self.layer_norm is not None:
x = self.layer_norm(x)
x = x.transpose(0, 1)
if not features_only:
x = self.output_layer(x)

View File

@ -348,8 +348,6 @@ class Encoder(nn.Module):
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
x = x.transpose(0, 1)
encoder_states = []
if return_all_hiddens:
@ -358,7 +356,7 @@ class Encoder(nn.Module):
rel_pos_bias = None
if self.relative_position is not None:
rel_pos_bias = self.relative_position(
batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
)
l_aux = []

View File

@ -24,6 +24,7 @@ class MultiheadAttention(nn.Module):
subln=False,
):
super().__init__()
self.args = args
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
@ -68,24 +69,26 @@ class MultiheadAttention(nn.Module):
attn_mask=None,
rel_pos=None,
):
tgt_len, bsz, embed_dim = query.size()
bsz, tgt_len, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
src_len, key_bsz, _ = key.size()
key_bsz, src_len, _ = key.size()
assert key_bsz == bsz, f"{query.size(), key.size()}"
assert value is not None
assert src_len, bsz == value.shape[:2]
assert bsz, src_len == value.shape[:2]
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
if incremental_state is not None:
if "prev_key" in incremental_state:
@ -138,7 +141,7 @@ class MultiheadAttention(nn.Module):
attn_probs = self.dropout_module(attn_weights)
attn = torch.bmm(attn_probs, v)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
if self.inner_attn_ln is not None:
attn = self.inner_attn_ln(attn)

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
def MultiwayWrapper(args, module, dim=0):
def MultiwayWrapper(args, module, dim=1):
if args.multiway:
return MultiwayNetwork(module, dim=dim)
return module
@ -22,7 +22,7 @@ def set_split_position(position):
class MultiwayNetwork(nn.Module):
def __init__(self, module, dim=0):
def __init__(self, module, dim=1):
super().__init__()
self.dim = dim
self.A = module