commit 82f140a6c4
@@ -371,7 +371,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         )
 
     def reorder_encoder_out(self, encoder_out, new_order):
-        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
+        new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
         new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
             0, new_order
         )
@@ -382,7 +382,7 @@ class MTEncoder(Encoder, FairseqEncoder):
         encoder_states = encoder_out["encoder_states"]
         if len(encoder_states) > 0:
             for idx, state in enumerate(encoder_states):
-                encoder_states[idx] = state.index_select(1, new_order)
+                encoder_states[idx] = state.index_select(0, new_order)
 
         return {
             "encoder_out": new_encoder_out,  # T x B x C
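With encoder_out now stored batch-first, beam-search reordering selects along dim 0 instead of dim 1. A minimal sketch of the equivalence on plain tensors (not the fairseq reorder API):

import torch

bsz, src_len, dim = 3, 7, 16
new_order = torch.tensor([2, 0, 1])            # e.g. a beam-search reordering

seq_first = torch.randn(src_len, bsz, dim)     # old T x B x C layout
batch_first = seq_first.transpose(0, 1)        # new B x T x C layout

old = seq_first.index_select(1, new_order)     # batch axis was dim 1
new = batch_first.index_select(0, new_order)   # batch axis is now dim 0

assert torch.equal(old.transpose(0, 1), new)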
@@ -189,9 +189,7 @@ class DecoderLayer(nn.Module):
             x = self.ffn(x)
             l_aux = None
         else:
-            x = x.transpose(0, 1)
             x, l_aux = self.moe_layer(x)
-            x = x.transpose(0, 1)
 
         if self.drop_path is not None:
             x = self.drop_path(x)
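The transposes around self.moe_layer can go because x now stays batch-first through the decoder layer; the old seq-first code flipped to the layout the MoE layer consumes and flipped back afterwards. A minimal shape sketch with a hypothetical stand-in for the MoE layer (the real expert routing is not reproduced):

import torch
import torch.nn as nn


class DummyMoE(nn.Module):
    """Hypothetical stand-in: consumes (batch, seq, dim) and returns (output, l_aux)."""

    def __init__(self, dim):
        super().__init__()
        self.ffn = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ffn(x), None


moe = DummyMoE(16)
x = torch.randn(2, 5, 16)   # activations are already batch-first, no transpose needed
out, l_aux = moe(x)
assert out.shape == x.shape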
@@ -391,26 +389,25 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
-        x = x.transpose(0, 1)
 
         # relative position
         self_attn_rel_pos_bias = None
         slen = prev_output_tokens.size(1)
         if self.self_attn_relative_position is not None:
             self_attn_rel_pos_bias = self.self_attn_relative_position(
-                batch_size=x.size(1), qlen=slen, klen=slen
+                batch_size=x.size(0), qlen=slen, klen=slen
             )
             if incremental_state is not None:
-                self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
+                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
             cross_attn_rel_pos_bias = self.cross_attn_relative_position(
-                batch_size=x.size(1),
+                batch_size=x.size(0),
                 qlen=slen,
-                klen=encoder_out["encoder_out"].size(0),
+                klen=encoder_out["encoder_out"].size(1),
             )
             if incremental_state is not None:
-                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
+                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
 
         # decoder layers
         inner_states = [x]
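Under the batch-first layout the batch size is read from x.size(0) and the cross-attention key length from encoder_out["encoder_out"].size(1); slen still comes from prev_output_tokens.size(1) on both sides of the change. A minimal sketch of the size bookkeeping on plain tensors (the relative-position module itself is not reproduced):

import torch

bsz, tgt_len, src_len, dim = 2, 4, 9, 16
x = torch.randn(bsz, tgt_len, dim)                             # decoder input, B x T x C
encoder_out = {"encoder_out": torch.randn(bsz, src_len, dim)}  # encoder output, B x S x C

batch_size = x.size(0)                       # was x.size(1) under T x B x C
qlen = tgt_len                               # still prev_output_tokens.size(1)
klen = encoder_out["encoder_out"].size(1)    # was .size(0) under S x B x C
assert (batch_size, qlen, klen) == (bsz, tgt_len, src_len)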
@@ -423,7 +420,7 @@ class Decoder(nn.Module):
         for idx, layer in enumerate(self.layers):
             if incremental_state is None:
                 self_attn_mask = torch.triu(
-                    torch.zeros([x.size(0), x.size(0)])
+                    torch.zeros([x.size(1), x.size(1)])
                     .float()
                     .fill_(float("-inf"))
                     .type_as(x),
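The causal mask is still a square (tgt_len, tgt_len) upper-triangular block of -inf; only the dimension it is sized from moves, since the sequence length now lives at x.size(1). A minimal sketch, with the trailing diagonal argument (1) assumed from the usual causal-mask pattern because the hunk cuts off before it:

import torch

x = torch.randn(2, 4, 16)  # B x T x C
self_attn_mask = torch.triu(
    torch.zeros([x.size(1), x.size(1)])
    .float()
    .fill_(float("-inf"))
    .type_as(x),
    1,
)
assert self_attn_mask.shape == (4, 4)
assert self_attn_mask[0, 1].item() == float("-inf")  # future positions are masked
assert self_attn_mask[1, 0].item() == 0.0            # past positions stay visible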
@@ -452,8 +449,6 @@ class Decoder(nn.Module):
         if self.layer_norm is not None:
             x = self.layer_norm(x)
 
-        x = x.transpose(0, 1)
-
         if not features_only:
             x = self.output_layer(x)
 
@@ -348,8 +348,6 @@ class Encoder(nn.Module):
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
-        x = x.transpose(0, 1)
-
         encoder_states = []
 
         if return_all_hiddens:
@@ -358,7 +356,7 @@ class Encoder(nn.Module):
         rel_pos_bias = None
         if self.relative_position is not None:
             rel_pos_bias = self.relative_position(
-                batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
+                batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )
 
         l_aux = []
@@ -24,6 +24,7 @@ class MultiheadAttention(nn.Module):
         subln=False,
     ):
         super().__init__()
+        self.args = args
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
@@ -68,24 +69,26 @@ class MultiheadAttention(nn.Module):
         attn_mask=None,
         rel_pos=None,
     ):
-        tgt_len, bsz, embed_dim = query.size()
+        bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
         assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]
 
-        src_len, key_bsz, _ = key.size()
+        key_bsz, src_len, _ = key.size()
         assert key_bsz == bsz, f"{query.size(), key.size()}"
         assert value is not None
-        assert src_len, bsz == value.shape[:2]
+        assert bsz, src_len == value.shape[:2]
 
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
         q *= self.scaling
 
-        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
 
         if incremental_state is not None:
             if "prev_key" in incremental_state:
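Both layouts end with q, k and v shaped (bsz * num_heads, seq_len, head_dim) for the bmm calls; the batch-first code just gets there via view + transpose(1, 2) + reshape instead of the old view + transpose(0, 1). A minimal sketch checking that the two routes produce the same tensor:

import torch

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim
q = torch.randn(bsz, tgt_len, embed_dim)   # batch-first q_proj output

# New batch-first route: split heads, then flatten (bsz, num_heads) for bmm.
q_new = (
    q.view(bsz, tgt_len, num_heads, head_dim)
    .transpose(1, 2)
    .reshape(bsz * num_heads, tgt_len, head_dim)
)

# Old seq-first route, fed the same values transposed to T x B x C.
q_old = (
    q.transpose(0, 1)
    .reshape(tgt_len, bsz * num_heads, head_dim)
    .transpose(0, 1)
)

assert torch.equal(q_new, q_old)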
@@ -138,7 +141,7 @@ class MultiheadAttention(nn.Module):
         attn_probs = self.dropout_module(attn_weights)
 
         attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
 
         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
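After torch.bmm the attention output is (bsz * num_heads, tgt_len, head_dim) in both layouts; the extra trailing transpose(0, 1) simply returns the merged heads to batch-first (bsz, tgt_len, embed_dim). A minimal shape sketch:

import torch

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim
attn = torch.randn(bsz * num_heads, tgt_len, head_dim)   # torch.bmm(attn_probs, v) output

out = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
assert out.shape == (bsz, tgt_len, embed_dim)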
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 
-def MultiwayWrapper(args, module, dim=0):
+def MultiwayWrapper(args, module, dim=1):
     if args.multiway:
         return MultiwayNetwork(module, dim=dim)
     return module
@@ -22,7 +22,7 @@ def set_split_position(position):
 
 
 class MultiwayNetwork(nn.Module):
-    def __init__(self, module, dim=0):
+    def __init__(self, module, dim=1):
         super().__init__()
         self.dim = dim
         self.A = module
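MultiwayNetwork routes two spans of its input through separate copies of the wrapped module, splitting at a stored position along self.dim; since its forward pass is not part of this diff, the torch.split call below is only an assumed illustration of which axis the new default dim=1 refers to once activations are batch-first:

import torch

x = torch.randn(2, 10, 16)   # batch-first: B x T x C
split_position = 4           # hypothetical split point along the sequence axis

# dim=1 is the sequence axis under B x T x C; it was dim=0 under T x B x C.
x1, x2 = torch.split(x, [split_position, x.size(1) - split_position], dim=1)
assert x1.shape == (2, 4, 16) and x2.shape == (2, 6, 16)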