diff --git a/vall_e/ar/model.py b/vall_e/ar/model.py
index 479b87e..f017ba9 100644
--- a/vall_e/ar/model.py
+++ b/vall_e/ar/model.py
@@ -171,22 +171,11 @@ class Block(nn.Sequential):
         return x
 
 
-class EmbeddingWithPE(nn.Module):
-    def __init__(self, num_tokens, token_dim):
-        super().__init__()
-        self.embedding = nn.Embedding(num_tokens, token_dim)
-        self.sin_emb = SinusodialEmbedding(token_dim)
-
-    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
-        if len(x_list) == 0:
+class ListEmbedding(nn.Embedding):
+    def forward(self, x: list[Tensor]) -> list[Tensor]:
+        if len(x) == 0:
             return []
-
-        x = pad_sequence(x_list, batch_first=True)  # b t
-        x = self.embedding(x)  # b t d
-        x = self.sin_emb.add_pe(x)
-        x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
-
-        return x_list
+        return super().forward(torch.cat(x)).split([*map(len, x)])
 
 
 def _join(x: tuple[Tensor], sep: Tensor):
@@ -212,9 +201,10 @@ class VALLEAR(nn.Module):
     ):
         super().__init__()
         # Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
-        self.text_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.prompt_emb = EmbeddingWithPE(num_tokens, d_model)
-        self.output_emb = EmbeddingWithPE(num_tokens, d_model)
+        self.text_emb = ListEmbedding(num_tokens, d_model)
+        self.prompt_emb = ListEmbedding(num_tokens, d_model)
+        self.output_emb = ListEmbedding(num_tokens, d_model)
+        self.sin_emb = SinusodialEmbedding(d_model)
         self.sep = nn.Parameter(torch.randn(d_model))  # start of sequence token
         self.blocks = nn.ModuleList(
             [Block(d_model, num_heads, dropout) for _ in range(num_layers)]
@@ -254,6 +244,7 @@ class VALLEAR(nn.Module):
         )
 
         x, m = _list_to_tensor(x_list)
+        x = self.sin_emb.add_pe(x)
 
         for block in self.blocks:
             x = block(x, m)
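
Review note (not part of the diff): the new `ListEmbedding` drops the per-list padding and the built-in positional encoding; it concatenates the token-id lists, does one `nn.Embedding` lookup, and splits the result back to the original lengths. Positional encodings are instead added once in `VALLEAR.forward`, after `_list_to_tensor` pads the joined sequence, so they appear to run over the full joined text/prompt/output sequence rather than restarting per segment. Below is a minimal, self-contained sketch of the new embedding behavior; the class body is copied from the diff, while the toy sizes, tensors, and assertions are illustrative and not from the repo.

```python
import torch
from torch import Tensor, nn


class ListEmbedding(nn.Embedding):
    """Embeds a list of variable-length token-id tensors without padding."""

    def forward(self, x: list[Tensor]) -> list[Tensor]:
        if len(x) == 0:
            return []
        # One flat lookup over the concatenated ids, then split back per sequence.
        return super().forward(torch.cat(x)).split([*map(len, x)])


if __name__ == "__main__":
    emb = ListEmbedding(100, 8)  # toy vocab size / model dim, illustrative only
    xs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    ys = emb(xs)
    # Each output keeps its own length, and rows match a plain weight lookup.
    assert [y.shape for y in ys] == [(3, 8), (2, 8)]
    assert all(torch.equal(y, emb.weight[xi]) for y, xi in zip(ys, xs))
```

One small observation for reviewers: `Tensor.split` returns a tuple, so the `list[Tensor]` annotation is slightly loose, but callers that only iterate or index the result are unaffected.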