Delay positional encoding (PE) until after joining, so that it also covers the sep token
This commit is contained in:
parent
4b75b3adf3
commit
2296e2ea3c
|
@ -171,22 +171,11 @@ class Block(nn.Sequential):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingWithPE(nn.Module):
|
class ListEmbedding(nn.Embedding):
|
||||||
def __init__(self, num_tokens, token_dim):
|
def forward(self, x: list[Tensor]) -> list[Tensor]:
|
||||||
super().__init__()
|
if len(x) == 0:
|
||||||
self.embedding = nn.Embedding(num_tokens, token_dim)
|
|
||||||
self.sin_emb = SinusodialEmbedding(token_dim)
|
|
||||||
|
|
||||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
|
||||||
if len(x_list) == 0:
|
|
||||||
return []
|
return []
|
||||||
|
return super().forward(torch.cat(x)).split([*map(len, x)])
|
||||||
x = pad_sequence(x_list, batch_first=True) # b t
|
|
||||||
x = self.embedding(x) # b t d
|
|
||||||
x = self.sin_emb.add_pe(x)
|
|
||||||
x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
|
|
||||||
|
|
||||||
return x_list
|
|
||||||
|
|
||||||
|
|
||||||
def _join(x: tuple[Tensor], sep: Tensor):
|
def _join(x: tuple[Tensor], sep: Tensor):
|
||||||
|
@ -212,9 +201,10 @@ class VALLEAR(nn.Module):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
|
# Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
|
||||||
self.text_emb = EmbeddingWithPE(num_tokens, d_model)
|
self.text_emb = ListEmbedding(num_tokens, d_model)
|
||||||
self.prompt_emb = EmbeddingWithPE(num_tokens, d_model)
|
self.prompt_emb = ListEmbedding(num_tokens, d_model)
|
||||||
self.output_emb = EmbeddingWithPE(num_tokens, d_model)
|
self.output_emb = ListEmbedding(num_tokens, d_model)
|
||||||
|
self.sin_emb = SinusodialEmbedding(d_model)
|
||||||
self.sep = nn.Parameter(torch.randn(d_model)) # start of sequence token
|
self.sep = nn.Parameter(torch.randn(d_model)) # start of sequence token
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[Block(d_model, num_heads, dropout) for _ in range(num_layers)]
|
[Block(d_model, num_heads, dropout) for _ in range(num_layers)]
|
||||||
|
@ -254,6 +244,7 @@ class VALLEAR(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
x, m = _list_to_tensor(x_list)
|
x, m = _list_to_tensor(x_list)
|
||||||
|
x = self.sin_emb.add_pe(x)
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, m)
|
x = block(x, m)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user