Delay pe to cover sep
parent 4b75b3adf3
commit 2296e2ea3c
@@ -171,22 +171,11 @@ class Block(nn.Sequential):
         return x
 
 
-class EmbeddingWithPE(nn.Module):
-    def __init__(self, num_tokens, token_dim):
-        super().__init__()
-        self.embedding = nn.Embedding(num_tokens, token_dim)
-        self.sin_emb = SinusodialEmbedding(token_dim)
-
-    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
-        if len(x_list) == 0:
+class ListEmbedding(nn.Embedding):
+    def forward(self, x: list[Tensor]) -> list[Tensor]:
+        if len(x) == 0:
             return []
-
-        x = pad_sequence(x_list, batch_first=True)  # b t
-        x = self.embedding(x)  # b t d
-        x = self.sin_emb.add_pe(x)
-        x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
-
-        return x_list
+        return super().forward(torch.cat(x)).split([*map(len, x)])
 
 
 def _join(x: tuple[Tensor], sep: Tensor):
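Note on the new module: ListEmbedding drops the padding round-trip and the per-module positional embedding. It concatenates the variable-length token sequences, embeds them in a single nn.Embedding call, and splits the result back per sequence. A quick usage sketch (the vocabulary size, model width, and inputs below are illustrative, not taken from the repo):

    import torch
    from torch import nn, Tensor

    class ListEmbedding(nn.Embedding):
        # Same class as in the hunk above: embed a list of 1-D token tensors without padding.
        def forward(self, x: list[Tensor]) -> list[Tensor]:
            if len(x) == 0:
                return []
            return super().forward(torch.cat(x)).split([*map(len, x)])

    emb = ListEmbedding(num_embeddings=1024, embedding_dim=512)  # illustrative sizes
    seqs = [torch.randint(0, 1024, (7,)), torch.randint(0, 1024, (3,))]
    out = emb(seqs)
    assert [o.shape for o in out] == [(7, 512), (3, 512)]  # one (t_i, d) tensor per input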
@@ -212,9 +201,10 @@ class VALLEAR(nn.Module):
     ):
         super().__init__()
         # Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
-        self.text_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.prompt_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.output_emb = EmbeddingWithPE(num_tokens, d_model)
+        self.text_emb = ListEmbedding(num_tokens, d_model)
+        self.prompt_emb = ListEmbedding(num_tokens, d_model)
+        self.output_emb = ListEmbedding(num_tokens, d_model)
+        self.sin_emb = SinusodialEmbedding(d_model)
         self.sep = nn.Parameter(torch.randn(d_model))  # start of sequence token
         self.blocks = nn.ModuleList(
             [Block(d_model, num_heads, dropout) for _ in range(num_layers)]
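With the per-embedding PE removed, VALLEAR now owns a single SinusodialEmbedding and applies it once, later in forward (next hunk). The implementation of SinusodialEmbedding is not part of this diff; as a rough mental model only, add_pe presumably adds fixed sin/cos encodings to a padded (b, t, d) tensor, along these lines (an assumption, not the repo's code):

    import math
    import torch
    from torch import nn, Tensor

    class SinusoidalPESketch(nn.Module):
        # Hypothetical stand-in for the repo's SinusodialEmbedding; only the add_pe
        # interface comes from the diff, the body is a standard transformer PE
        # (assumes an even d_model).
        def __init__(self, d_model: int):
            super().__init__()
            self.d_model = d_model

        def add_pe(self, x: Tensor) -> Tensor:  # x: (b, t, d)
            t = x.shape[1]
            pos = torch.arange(t, device=x.device).unsqueeze(1)        # (t, 1)
            div = torch.exp(
                torch.arange(0, self.d_model, 2, device=x.device)
                * (-math.log(10000.0) / self.d_model)
            )                                                          # (d/2,)
            pe = torch.zeros(t, self.d_model, device=x.device)
            pe[:, 0::2] = torch.sin(pos * div)
            pe[:, 1::2] = torch.cos(pos * div)
            return x + pe.unsqueeze(0)                                 # broadcast over batch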
@@ -254,6 +244,7 @@ class VALLEAR(nn.Module):
         )
 
         x, m = _list_to_tensor(x_list)
+        x = self.sin_emb.add_pe(x)
 
         for block in self.blocks:
             x = block(x, m)
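That single added line is the commit title in action: positional encoding is now applied after the embedded segments have been joined with the learned sep vector (and padded by _list_to_tensor), so the sep positions are covered by the PE as well. Neither _join nor _list_to_tensor is shown in this diff; the sketch below guesses at _join from its signature purely to illustrate the new ordering:

    import torch
    from torch import nn, Tensor

    # Toy ordering after this commit: embed -> join with sep -> (pad) -> add PE.
    d_model = 8                                # illustrative width
    sep = nn.Parameter(torch.randn(d_model))

    def join(segs: tuple[Tensor, ...], sep: Tensor) -> Tensor:
        # Guess at _join: interleave already-embedded segments with the sep vector.
        pieces = []
        for s in segs:
            pieces += [s, sep.unsqueeze(0)]
        return torch.cat(pieces)               # (sum_i t_i + len(segs), d_model)

    text = torch.randn(5, d_model)             # embedded text tokens, no PE yet
    prompt = torch.randn(3, d_model)           # embedded prompt tokens, no PE yet
    x = join((text, prompt), sep)              # (5 + 1 + 3 + 1, d_model)
    # Before this commit, PE was added per segment inside EmbeddingWithPE, so the sep
    # rows carried no position. Now it is added once over the joined/padded batch:
    # x = sin_emb.add_pe(x.unsqueeze(0))       # as in the hunk above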