added homebrewed per-RVQ-bin embedding solutions

mrq 2023-09-07 16:48:02 -05:00
parent e7a67410d1
commit b2c2dec291
4 changed files with 29 additions and 39 deletions

@@ -54,7 +54,7 @@ class AR(Base):
return False
@property
def dual(self) -> bool:
def monolithic(self) -> bool:
return False
def _prune(self, l: Tensor):

@@ -52,7 +52,7 @@ class AR_NAR(Base):
return False
@property
def dual(self) -> bool:
def monolithic(self) -> bool:
return True
def _prune(self, l: Tensor):

@@ -78,7 +78,8 @@ class MultiEmbedding(nn.Embedding):
for xi in x_list:
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
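For context, the padding being reworked in this hunk feeds MultiEmbedding's one-hot lookup: each sample is one-hot encoded per RVQ level, zero-padded up to the full level count, and (outside the lines shown) presumably contracted against a (levels, tokens, dim) weight so every active level's embedding gets summed per frame. A minimal standalone sketch of that idea, with shapes and names assumed rather than taken from the repo:

import torch
import torch.nn.functional as F

n_levels, n_tokens, d_model = 8, 1024, 512
w = torch.randn(n_levels, n_tokens, d_model)          # fused per-level tables, "l k d"

xi = torch.randint(0, n_tokens, (75, 2))              # one sample: t=75 frames, only 2 RVQ levels present
xi = F.one_hot(xi, num_classes=n_tokens).float()      # t l' k
wi = n_levels - xi.shape[1]                           # levels missing from this sample
xi = F.pad(xi, (0, 0, 0, wi))                         # t l k, missing levels stay all-zero
out = torch.einsum("l k d, t l k -> t d", w, xi)      # sum of the active levels' embeddings per frame
print(out.shape)                                      # torch.Size([75, 512])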
@@ -87,7 +88,8 @@ class MultiEmbedding(nn.Embedding):
x_list = x.split([*map(len, x_list)])
return x_list
"""
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
@@ -95,11 +97,9 @@ class PromEmbedding(nn.Module):
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
if len(x_list) == 0:
return []
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
# Embedding that selects which embedding based on a quant_level tensor for a given batch
class RespEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
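The PromEmbedding change above collapses its forward into a single list comprehension: for each prompt in the batch, every RVQ level's column is looked up in its own embedding table and the results are summed. Roughly, as a standalone sketch with assumed shapes:

import torch
from torch import nn

n_levels, n_tokens, d_model = 8, 1024, 512
embeddings = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])

xi = torch.randint(0, n_tokens, (75, n_levels))                      # acoustic prompt: t=75 frames, one column per RVQ level
summed = sum(embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))   # each level embedded separately, then summed
print(summed.shape)                                                  # torch.Size([75, 512])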
@@ -107,11 +107,8 @@ class RespEmbedding(nn.Module):
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
if len(x_list) == 0:
return []
res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
return res
"""
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
class Base(nn.Module):
@property
def causal(self) -> bool:
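RespEmbedding, by contrast, keeps one table per RVQ level and selects a single table per batch item from quant_levels, falling back to level 0 for the AR pass. A standalone sketch of that selection, with names and shapes assumed:

import torch
from torch import nn

n_levels, n_tokens, d_model = 8, 1024, 512
embeddings = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])

# two responses in a batch, each being trained against a different quantization level
x_list = [torch.randint(0, n_tokens, (75,)), torch.randint(0, n_tokens, (60,))]
quant_levels = torch.tensor([1, 3])

out = [embeddings[int(quant_levels[i])](xi) for i, xi in enumerate(x_list)]   # level 0 would be used when quant_levels is None
print([o.shape for o in out])   # [torch.Size([75, 512]), torch.Size([60, 512])]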
@@ -150,12 +147,12 @@ class Base(nn.Module):
return False
@property
def dual(self) -> bool:
def monolithic(self) -> bool:
return False
@property
def n_embeddings(self):
return self.n_resp_levels if self.dual else 1
return self.n_resp_levels if self.monolithic else 1
@property
def stop_token(self):
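The rename above also pins down what drives the embedding count: the combined AR_NAR reports monolithic = True and therefore asks for one embedding per response level, while the split AR and NAR keep a single embedding. A tiny sketch of the property's effect, with the level count assumed:

n_resp_levels = 8
for monolithic in (False, True):                       # split AR / NAR vs the combined AR_NAR
    n_embeddings = n_resp_levels if monolithic else 1
    print(monolithic, n_embeddings)                    # False 1, then True 8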
@@ -200,17 +197,19 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
if self.n_embeddings == 1:
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
if self.n_embeddings == self.n_prom_levels:
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
else:
self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
"""
if self.n_embeddings == 1:
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
# n_embeddings > 1 because using the MultiEmbedding "works" fine enough for split AR/NARs.
if self.n_embeddings > 1:
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
else:
self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
"""
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
@@ -295,21 +294,12 @@ class Base(nn.Module):
state: dict | None = None,
):
if self.n_embeddings == 1:
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(resps_list),
sep=self.sep,
)
else:
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb[0 if quant_levels is None else 1](resps_list),
#self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
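With the branch removed, the forward pass always builds one merged sequence per sample and lets each embedding module handle its own level logic. Assuming _samplewise_merge_tensors interleaves the learned sep vector between each sample's text, prompt, and response embeddings (its body is not shown in this diff), the merge amounts to something like:

import torch

def merge_with_sep(text, prom, resp, sep):
    # one sample: concatenate the three segments along time with the separator embedding between them
    return torch.cat([text, sep[None], prom, sep[None], resp], dim=0)

d_model = 512
sep = torch.randn(d_model)
text, prom, resp = torch.randn(12, d_model), torch.randn(75, d_model), torch.randn(40, d_model)
x = merge_with_sep(text, prom, resp, sep)
print(x.shape)   # torch.Size([129, 512]) -> 12 + 1 + 75 + 1 + 40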

@@ -48,7 +48,7 @@ class NAR(Base):
return False
@property
def dual(self) -> bool:
def monolithic(self) -> bool:
return False
def forward(