added homebrewed per-RVQ-bin embedding solutions
parent e7a67410d1
commit b2c2dec291
@@ -54,7 +54,7 @@ class AR(Base):
		return False

	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
		return False

	def _prune(self, l: Tensor):
@@ -52,7 +52,7 @@ class AR_NAR(Base):
		return False

	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
		return True

	def _prune(self, l: Tensor):
@@ -78,7 +78,8 @@ class MultiEmbedding(nn.Embedding):

		for xi in x_list:
			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
-			xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
+			wi = w.shape[0] - xi.shape[1]
+			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
			padded_x_list.append(xi.to(w))

		x = torch.cat(padded_x_list) # n l k
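For context, a minimal standalone sketch of the padding this hunk refactors into an explicit wi intermediate (the toy shapes, the einsum reduction, and the single-sequence driver are assumptions for illustration, not part of the diff): each sequence's codes are one-hot encoded against the token vocabulary, padded up to the embedding's maximum number of RVQ levels, then contracted against a per-level weight so the levels are summed into one vector per timestep.

    import torch
    import torch.nn.functional as F

    # assumed toy shapes: xi holds (t, l') integer codes, w is a (l, n_tokens, d) per-level weight
    t, l_present, l_max, n_tokens, d = 5, 2, 4, 10, 8
    xi = torch.randint(0, n_tokens, (t, l_present))
    w = torch.randn(l_max, n_tokens, d)

    xi = F.one_hot(xi.to(torch.int64), num_classes=n_tokens)  # t l' k
    wi = w.shape[0] - xi.shape[1]                              # levels to pad, the new intermediate
    xi = F.pad(xi, (0, 0, 0, wi)).to(w)                        # t l k
    x = torch.einsum("l k d, t l k -> t d", w, xi)             # sum the selected embeddings across levels
    print(x.shape)  # torch.Size([5, 8])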
@@ -87,7 +88,8 @@ class MultiEmbedding(nn.Embedding):
		x_list = x.split([*map(len, x_list)])

		return x_list
	"""

# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):
	def __init__(self, n_levels, n_tokens, token_dim):
		super().__init__()
@@ -95,11 +97,9 @@ class PromEmbedding(nn.Module):
		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

	def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
		if len(x_list) == 0:
			return []

		return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]

# Embedding that selects which embedding based on a quant_level tensor for a given batch
class RespEmbedding(nn.Module):
	def __init__(self, n_levels, n_tokens, token_dim):
		super().__init__()
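A self-contained sketch of the summing behaviour PromEmbedding implements here (the forward body follows the hunk; the n_levels assignment, toy shapes, and driver are assumptions added to make it runnable): level k of the prompt indexes its own embedding table, and the per-level embeddings are summed into one vector per timestep.

    import torch
    from torch import Tensor, nn

    class PromEmbedding(nn.Module):
        def __init__(self, n_levels, n_tokens, token_dim):
            super().__init__()
            self.n_levels = n_levels  # assumed; this assignment sits outside the hunk
            self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

        def forward(self, x_list: list[Tensor]) -> list[Tensor]:
            if len(x_list) == 0:
                return []
            # route the k-th level's codes through the k-th table, then sum over levels
            return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for xi in x_list ]

    # assumed usage: two acoustic prompts of (t, k) codes with k = 4 RVQ levels
    emb = PromEmbedding(n_levels=4, n_tokens=1024, token_dim=16)
    prompts = [torch.randint(0, 1024, (7, 4)), torch.randint(0, 1024, (3, 4))]
    print([o.shape for o in emb(prompts)])  # [torch.Size([7, 16]), torch.Size([3, 16])]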
@@ -107,11 +107,8 @@ class RespEmbedding(nn.Module):
		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
		if len(x_list) == 0:
			return []
		res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
		return res
		"""
		return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]

class Base(nn.Module):
	@property
	def causal(self) -> bool:
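A sketch of the per-sample selection RespEmbedding performs (following the simpler of the two forward variants visible in this hunk; the n_levels assignment and the toy driver are assumptions): each batch entry looks up the embedding table matching its quantization level, defaulting to table 0 when quant_levels is not given.

    import torch
    from torch import Tensor, nn

    class RespEmbedding(nn.Module):
        def __init__(self, n_levels, n_tokens, token_dim):
            super().__init__()
            self.n_levels = n_levels  # assumed; set outside the hunk
            self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

        def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
            if len(x_list) == 0:
                return []
            # one table per sample, chosen by that sample's quantization level
            return [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]

    # assumed usage: two response sequences, the second trained against RVQ level 2
    emb = RespEmbedding(n_levels=4, n_tokens=1024, token_dim=16)
    resps = [torch.randint(0, 1024, (7,)), torch.randint(0, 1024, (3,))]
    print([o.shape for o in emb(resps, quant_levels=torch.tensor([0, 2]))])  # two (t, 16) tensors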
@@ -150,12 +147,12 @@ class Base(nn.Module):
		return False

	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
		return False

	@property
	def n_embeddings(self):
-		return self.n_resp_levels if self.dual else 1
+		return self.n_resp_levels if self.monolithic else 1

	@property
	def stop_token(self):
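A minimal illustration (assumed, not part of the diff) of what the renamed flag feeds: a monolithic AR+NAR model reports one response embedding per RVQ level, while a split AR or NAR model reports a single one.

    # assumed toy stand-ins for Base and AR_NAR, just to show the property wiring
    class ToyBase:
        n_resp_levels = 8

        @property
        def monolithic(self) -> bool:
            return False  # overridden to True by the combined AR_NAR model

        @property
        def n_embeddings(self):
            return self.n_resp_levels if self.monolithic else 1

    class ToyARNAR(ToyBase):
        @property
        def monolithic(self) -> bool:
            return True

    print(ToyBase().n_embeddings, ToyARNAR().n_embeddings)  # 1 8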
@@ -200,17 +197,19 @@ class Base(nn.Module):

		self.text_emb = Embedding(n_tokens, d_model)

		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
		if self.n_embeddings == 1:
			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
		if self.n_embeddings == self.n_prom_levels:
			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
		else:
			self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
		"""
		if self.n_embeddings == 1:
			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)

		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
		# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
		if self.n_embeddings > 1:
			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
		else:
			self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
		"""
		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

		self.sep = nn.Parameter(torch.randn(d_model))
@@ -295,21 +294,12 @@ class Base(nn.Module):

		state: dict | None = None,
	):
		if self.n_embeddings == 1:
			x_list = self._samplewise_merge_tensors(
				self.text_emb(text_list),
				self.proms_emb(proms_list),
				self.resps_emb(resps_list),
				sep=self.sep,
			)
		else:
			x_list = self._samplewise_merge_tensors(
				self.text_emb(text_list),
				self.proms_emb(proms_list),
				self.resps_emb[0 if quant_levels is None else 1](resps_list),
				#self.resps_emb(resps_list, quant_levels),
				sep=self.sep,
			)
		x_list = self._samplewise_merge_tensors(
			self.text_emb(text_list),
			self.proms_emb(proms_list),
			self.resps_emb(resps_list, quant_levels),
			sep=self.sep,
		)

		x, m = list_to_tensor(x_list)
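A rough sketch of the sequence assembly this hunk simplifies (the helper below is an assumption standing in for _samplewise_merge_tensors, whose body is not part of the diff): per sample, the text, prompt, and response embeddings are concatenated with a learned separator vector between segments, and quant_levels is now always forwarded to the response embedding.

    import torch

    def merge_with_sep(text, prom, resp, sep):
        # assumed stand-in: concatenate one sample's segments with a separator vector between them
        sep = sep.unsqueeze(0)  # (1, d)
        return torch.cat([text, sep, prom, sep, resp], dim=0)

    d = 16
    sep = torch.randn(d)
    text, prom, resp = torch.randn(4, d), torch.randn(6, d), torch.randn(5, d)
    x = merge_with_sep(text, prom, resp, sep)
    print(x.shape)  # torch.Size([17, 16]) -> 4 + 1 + 6 + 1 + 5 timesteps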
@@ -48,7 +48,7 @@ class NAR(Base):
		return False

	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
		return False

	def forward(