Experimental: learnable per-level weighting of prom/resp embeddings

This commit is contained in:
mrq 2024-01-25 12:18:48 -06:00
parent c690aa509d
commit e799665759
2 changed files with 8 additions and 7 deletions

View File

@ -94,7 +94,7 @@ class AR_NAR(Base):
lang_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_levels: int = 7,
max_levels: int = 0,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
@ -166,7 +166,7 @@ class AR_NAR(Base):
)
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels
max_levels = self.n_resp_levels - 1
prev_list = resps_list

View File

@ -123,9 +123,10 @@ class MultiEmbedding(nn.Module):
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(self, l_tokens, token_dim):
def __init__(self, l_tokens, token_dim, levels=None):
super().__init__()
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
@ -133,13 +134,13 @@ class AudioEmbedding(nn.Module):
for i, xi in enumerate(x_list):
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
res_list.append(x)
return res_list
@ -255,9 +256,9 @@ class Base(nn.Module):
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
# [1024] * 8
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
if self.version >= 3: