Experimental per-level weighting of prom/resp embeddings
This commit is contained in:
parent
c690aa509d
commit
e799665759
|
@ -94,7 +94,7 @@ class AR_NAR(Base):
|
|||
lang_list: list[Tensor] | None = None,
|
||||
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 7,
|
||||
max_levels: int = 0,
|
||||
max_resp_context: int = -1,
|
||||
|
||||
sampling_temperature: float = 1.0,
|
||||
|
@ -166,7 +166,7 @@ class AR_NAR(Base):
|
|||
)
|
||||
# is NAR
|
||||
if max_levels == 0:
|
||||
max_levels = self.n_resp_levels
|
||||
max_levels = self.n_resp_levels - 1
|
||||
|
||||
prev_list = resps_list
|
||||
|
||||
|
|
|
@ -123,9 +123,10 @@ class MultiEmbedding(nn.Module):
|
|||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(self, l_tokens, token_dim):
|
||||
def __init__(self, l_tokens, token_dim, levels=None):
|
||||
super().__init__()
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
|
||||
res_list = []
|
||||
|
@ -133,13 +134,13 @@ class AudioEmbedding(nn.Module):
|
|||
for i, xi in enumerate(x_list):
|
||||
# prom
|
||||
if quant_levels is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
# AR resp
|
||||
elif quant_levels is None or quant_levels[i] == 0:
|
||||
x = self.embeddings[0]( xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
res_list.append(x)
|
||||
|
||||
return res_list
|
||||
|
@ -255,9 +256,9 @@ class Base(nn.Module):
|
|||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
else:
|
||||
# [1024] * 8
|
||||
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
|
||||
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
|
||||
# [1025] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
|
||||
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
|
||||
|
||||
|
||||
if self.version >= 3:
|
||||
|
|
Loading…
Reference in New Issue
Block a user