Delay pe to cover sep

enhuiz 2023-01-11 23:11:24 +08:00
parent 4b75b3adf3
commit 2296e2ea3c


@@ -171,22 +171,11 @@ class Block(nn.Sequential):
         return x
 
 
-class EmbeddingWithPE(nn.Module):
-    def __init__(self, num_tokens, token_dim):
-        super().__init__()
-        self.embedding = nn.Embedding(num_tokens, token_dim)
-        self.sin_emb = SinusodialEmbedding(token_dim)
-
-    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
-        if len(x_list) == 0:
+class ListEmbedding(nn.Embedding):
+    def forward(self, x: list[Tensor]) -> list[Tensor]:
+        if len(x) == 0:
             return []
-        x = pad_sequence(x_list, batch_first=True)  # b t
-        x = self.embedding(x)  # b t d
-        x = self.sin_emb.add_pe(x)
-        x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
-        return x_list
+        return super().forward(torch.cat(x)).split([*map(len, x)])
 
 
 def _join(x: tuple[Tensor], sep: Tensor):
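For reference, the new ListEmbedding added above is small enough to run on its own. The sketch below reproduces the added class verbatim and attaches a tiny usage demo; the vocabulary size, embedding dimension, and sample lengths in the demo are made up for illustration. Note that the positional embedding is no longer applied here; after this commit it happens later, inside VALLEAR.forward (see the last hunk).

    # Class body copied from the added lines above; demo values are illustrative.
    import torch
    from torch import Tensor, nn


    class ListEmbedding(nn.Embedding):
        def forward(self, x: list[Tensor]) -> list[Tensor]:
            if len(x) == 0:
                return []
            # Concatenate the ragged sequences, embed them in one pass, then
            # split back to the original per-sequence lengths. No padding and
            # no positional embedding here anymore.
            return super().forward(torch.cat(x)).split([*map(len, x)])


    emb = ListEmbedding(256, 8)
    out = emb([torch.randint(0, 256, (5,)), torch.randint(0, 256, (3,))])
    print([o.shape for o in out])  # [torch.Size([5, 8]), torch.Size([3, 8])]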
@@ -212,9 +201,10 @@ class VALLEAR(nn.Module):
     ):
         super().__init__()
         # Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
-        self.text_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.prompt_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.output_emb = EmbeddingWithPE(num_tokens, d_model)
+        self.text_emb = ListEmbedding(num_tokens, d_model)
+        self.prompt_emb = ListEmbedding(num_tokens, d_model)
+        self.output_emb = ListEmbedding(num_tokens, d_model)
+        self.sin_emb = SinusodialEmbedding(d_model)
         self.sep = nn.Parameter(torch.randn(d_model))  # start of sequence token
         self.blocks = nn.ModuleList(
             [Block(d_model, num_heads, dropout) for _ in range(num_layers)]
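With this hunk the model owns a single SinusodialEmbedding instead of one inside each embedding table. The repo's SinusodialEmbedding implementation is not part of this diff; the sketch below is only a plausible stand-in (a standard sine/cosine encoding with an add_pe helper matching how it is called here), not the actual code.

    # Stand-in for SinusodialEmbedding; assumes an even d_model.
    import math
    import torch
    from torch import Tensor, nn


    class SinusoidalPE(nn.Module):
        def __init__(self, d_model: int):
            super().__init__()
            self.d_model = d_model

        def pe(self, t: int) -> Tensor:
            # Standard transformer sin/cos table of shape (t, d).
            pos = torch.arange(t, dtype=torch.float32).unsqueeze(1)  # (t, 1)
            div = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * (-math.log(10000.0) / self.d_model)
            )
            out = torch.zeros(t, self.d_model)
            out[:, 0::2] = torch.sin(pos * div)
            out[:, 1::2] = torch.cos(pos * div)
            return out

        def add_pe(self, x: Tensor) -> Tensor:
            # x: (b, t, d); add one position vector per time step.
            return x + self.pe(x.shape[1]).to(x)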
@@ -254,6 +244,7 @@ class VALLEAR(nn.Module):
         )
         x, m = _list_to_tensor(x_list)
+        x = self.sin_emb.add_pe(x)
         for block in self.blocks:
             x = block(x, m)
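Taken together, the three hunks move the positional embedding from inside each embedding table to a single call after the embedded streams have been joined with the learned sep vector and padded into a batch, which is what the commit title means: the PE now also covers the sep positions. Below is a rough sketch of the resulting ordering, using hypothetical stand-ins for _join and _list_to_tensor (only their call sites appear in this diff, so their bodies and shapes are assumptions).

    import torch
    from torch import Tensor
    from torch.nn.utils.rnn import pad_sequence


    def _join(x: tuple[Tensor, ...], sep: Tensor) -> Tensor:
        # Hypothetical: place the learned sep vector between embedded segments.
        out: list[Tensor] = []
        for i, xi in enumerate(x):
            out.append(xi)
            if i < len(x) - 1:
                out.append(sep[None])  # (1, d)
        return torch.cat(out)          # (t, d)


    def _list_to_tensor(x_list: list[Tensor]) -> tuple[Tensor, Tensor]:
        # Hypothetical: pad (t, d) sequences to (b, t, d) plus a (b, t, 1) mask.
        x = pad_sequence(x_list, batch_first=True)
        m = pad_sequence([xi.new_ones(len(xi), 1) for xi in x_list], batch_first=True)
        return x, m

    # Ordering after this commit (the embeddings carry no PE of their own):
    #   x_list = [_join(parts, self.sep) for parts in ...]
    #   x, m = _list_to_tensor(x_list)
    #   x = self.sin_emb.add_pe(x)   # PE added last, so sep gets a position too
    #   for block in self.blocks:
    #       x = block(x, m)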