added homebrewed per-RVQ-bin embedding solutions
This commit is contained in:
parent
e7a67410d1
commit
b2c2dec291
@ -54,7 +54,7 @@ class AR(Base):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dual(self) -> bool:
|
def monolithic(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _prune(self, l: Tensor):
|
def _prune(self, l: Tensor):
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class AR_NAR(Base):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dual(self) -> bool:
|
def monolithic(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _prune(self, l: Tensor):
|
def _prune(self, l: Tensor):
|
||||||
|
|||||||
@ -78,7 +78,8 @@ class MultiEmbedding(nn.Embedding):
|
|||||||
|
|
||||||
for xi in x_list:
|
for xi in x_list:
|
||||||
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
||||||
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
|
wi = w.shape[0] - xi.shape[1]
|
||||||
|
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
|
||||||
padded_x_list.append(xi.to(w))
|
padded_x_list.append(xi.to(w))
|
||||||
|
|
||||||
x = torch.cat(padded_x_list) # n l k
|
x = torch.cat(padded_x_list) # n l k
|
||||||
@ -87,7 +88,8 @@ class MultiEmbedding(nn.Embedding):
|
|||||||
x_list = x.split([*map(len, x_list)])
|
x_list = x.split([*map(len, x_list)])
|
||||||
|
|
||||||
return x_list
|
return x_list
|
||||||
"""
|
|
||||||
|
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||||
class PromEmbedding(nn.Module):
|
class PromEmbedding(nn.Module):
|
||||||
def __init__(self, n_levels, n_tokens, token_dim):
|
def __init__(self, n_levels, n_tokens, token_dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -95,11 +97,9 @@ class PromEmbedding(nn.Module):
|
|||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||||
|
|
||||||
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
|
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
|
||||||
|
|
||||||
|
# Embedding that selects which embedding based on a quant_level tensor for a given batch
|
||||||
class RespEmbedding(nn.Module):
|
class RespEmbedding(nn.Module):
|
||||||
def __init__(self, n_levels, n_tokens, token_dim):
|
def __init__(self, n_levels, n_tokens, token_dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -107,11 +107,8 @@ class RespEmbedding(nn.Module):
|
|||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||||
|
|
||||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
|
||||||
return []
|
|
||||||
res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
|
|
||||||
return res
|
|
||||||
"""
|
|
||||||
class Base(nn.Module):
|
class Base(nn.Module):
|
||||||
@property
|
@property
|
||||||
def causal(self) -> bool:
|
def causal(self) -> bool:
|
||||||
@ -150,12 +147,12 @@ class Base(nn.Module):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dual(self) -> bool:
|
def monolithic(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_embeddings(self):
|
def n_embeddings(self):
|
||||||
return self.n_resp_levels if self.dual else 1
|
return self.n_resp_levels if self.monolithic else 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stop_token(self):
|
def stop_token(self):
|
||||||
@ -200,17 +197,19 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
self.text_emb = Embedding(n_tokens, d_model)
|
self.text_emb = Embedding(n_tokens, d_model)
|
||||||
|
|
||||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
|
||||||
if self.n_embeddings == 1:
|
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
|
||||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
if self.n_embeddings == self.n_prom_levels:
|
||||||
|
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||||
else:
|
else:
|
||||||
self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
|
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||||
"""
|
|
||||||
if self.n_embeddings == 1:
|
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
|
||||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
|
||||||
|
if self.n_embeddings > 1:
|
||||||
|
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
|
||||||
else:
|
else:
|
||||||
self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||||
"""
|
|
||||||
|
|
||||||
self.sep = nn.Parameter(torch.randn(d_model))
|
self.sep = nn.Parameter(torch.randn(d_model))
|
||||||
|
|
||||||
@ -295,21 +294,12 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
state: dict | None = None,
|
state: dict | None = None,
|
||||||
):
|
):
|
||||||
if self.n_embeddings == 1:
|
x_list = self._samplewise_merge_tensors(
|
||||||
x_list = self._samplewise_merge_tensors(
|
self.text_emb(text_list),
|
||||||
self.text_emb(text_list),
|
self.proms_emb(proms_list),
|
||||||
self.proms_emb(proms_list),
|
self.resps_emb(resps_list, quant_levels),
|
||||||
self.resps_emb(resps_list),
|
sep=self.sep,
|
||||||
sep=self.sep,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
x_list = self._samplewise_merge_tensors(
|
|
||||||
self.text_emb(text_list),
|
|
||||||
self.proms_emb(proms_list),
|
|
||||||
self.resps_emb[0 if quant_levels is None else 1](resps_list),
|
|
||||||
#self.resps_emb(resps_list, quant_levels),
|
|
||||||
sep=self.sep,
|
|
||||||
)
|
|
||||||
|
|
||||||
x, m = list_to_tensor(x_list)
|
x, m = list_to_tensor(x_list)
|
||||||
|
|
||||||
|
|||||||
@ -48,7 +48,7 @@ class NAR(Base):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dual(self) -> bool:
|
def monolithic(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user