mrq 2024-06-06 19:41:26 -05:00
parent ee25d2e62e
commit 516b0894d7


@@ -105,9 +105,7 @@ class AudioEmbedding_Old(nn.Module):
         self,
         l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
         token_dim: int, # dimensionality of the embedding
-        mode = "old", # old | prom | resp
         levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
-        sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
     ):
         super().__init__()
         # array of embeddings
@@ -115,31 +113,19 @@ class AudioEmbedding_Old(nn.Module):
         # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
         # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
-        self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
-        #
-        self.mode = mode
-        #
-        self.sums = sums
+        self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None

     def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
         # prom
         if quant_levels is None and xi.shape[-1] > 1:
-            if self.sums:
-                x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-            else:
-                k = 0 # only use the most significant RVQ bin level for the input prom
-                x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
+            x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
         # AR resp
         elif quant_levels is None or quant_levels == 0:
             x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
         # NAR resp
         else:
-            if self.sums:
-                x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-            else:
-                k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
-                x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
+            x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )

         return x

 class AudioEmbedding(nn.Module):
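With the mode/sums switches dropped, AudioEmbedding_Old always sums one embedding per RVQ level, each scaled by its learned per-level weight. A minimal standalone sketch of that prom path, with illustrative sizes (the 1024-entry codebooks, 8 levels, and frame count are assumptions, not values read from the config):

import torch
import torch.nn as nn

levels, n_tokens, token_dim = 8, 1024, 16  # assumed EnCodec-style sizes

embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(levels)])
weight = nn.ParameterList([nn.Parameter(torch.Tensor([1])) for _ in range(levels)])

# a prom: one code per RVQ level per frame, shape (frames, levels)
xi = torch.randint(0, n_tokens, (75, levels))

# embed level k's codes, scale by that level's learned weight, sum across levels
x = sum(embeddings[k](xi[:, k]) * weight[k] for k in range(xi.shape[-1]))
print(x.shape)  # torch.Size([75, 16])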
@@ -165,16 +151,16 @@ class AudioEmbedding(nn.Module):
         if quant_level is None:
             quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1

         # embeddings for AR/NAR cannot be shared
-        offset = 0 if self.mode == "prom" or quant_level == 0 else 1
-        # jank but needed
         if xi.dim() == 1:
-            x = self.embeddings[quant_level]( xi )
-        elif self.sums and quant_level > 0:
+            return self.embeddings[quant_level]( xi )
+
+        offset = 0 if self.mode == "prom" else 1
+        if self.sums and quant_level > 0:
             x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
         else:
-            k = quant_level
-            x = self.embeddings[k + offset]( xi[:, k] )
+            k = quant_level - 1
+            x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )

         return x
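The rewritten AudioEmbedding keeps the mode/sums switches but simplifies the dispatch: 1D inputs short-circuit through the current level's table, and for 2D resp inputs the offset of 1 skips table 0, which is reserved for the AR level and cannot be shared with the NAR. A sketch of that logic under assumed sizes (the embed helper, module-level mode/sums flags, and shapes are illustrative, not the repo's API):

import torch
import torch.nn as nn

n_levels, n_tokens, token_dim = 8, 1024, 16
mode, sums = "resp", True

# table 0 is AR-only (its vocab also holds the stop token); tables 1.. are NAR
embeddings = nn.ModuleList([nn.Embedding(n_tokens + (1 if i == 0 else 0), token_dim) for i in range(n_levels)])

def embed(xi, quant_level):
    if xi.dim() == 1:  # bare code sequence: look it up at its own level
        return embeddings[quant_level](xi)
    offset = 0 if mode == "prom" else 1
    if sums and quant_level > 0:
        # accumulate every level below the one being predicted
        return sum(embeddings[k + offset](xi[:, k]) for k in range(quant_level))
    k = quant_level - 1  # otherwise only the immediately prior level
    return embeddings[k + offset](xi[:, k])

resps = torch.randint(0, n_tokens, (75, 3))  # levels 0..2 already decoded
print(embed(resps, quant_level=3).shape)     # torch.Size([75, 16])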
@@ -302,15 +288,11 @@ class Base(nn.Module):
             self.proms_emb = AudioEmbedding_Old(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
                 levels=self.n_prom_levels if self.version > 3 else None,
-                mode="prom" if self.version >= 5 else "old",
-                sums=self.config.audio_embedding_sums if self.config is not None else True,
             )
             # [1024 + STOP] + [1024] * 8
             self.resps_emb = AudioEmbedding_Old(
                 [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
                 levels=self.n_resp_levels if self.version > 3 else None,
-                mode="resp" if self.version >= 5 else "old",
-                sums=self.config.audio_embedding_sums if self.config is not None else True
             )
         else:
             self.proms_emb = AudioEmbedding(
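The two per-level vocabulary lists differ only at the AR level, where the stop token lives, so every NAR table is one entry smaller. Working through the arithmetic with the 1024-entry codebooks the `# [1024 + STOP]` comment implies (illustrative values, not pulled from a config):

n_prom_tokens, n_resp_tokens = 1024, 1025  # resp vocab includes the stop token
n_prom_levels = n_resp_levels = 8

prom_l_tokens = [n_prom_tokens] * n_prom_levels
resp_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
print(prom_l_tokens)  # [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
print(resp_l_tokens)  # [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]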
@@ -686,7 +668,7 @@ class Base(nn.Module):
             elif name == "lang" and self.langs_emb is not None:
                 embedding = self.langs_emb( input )
             elif name == "prom":
-                embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
+                embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
             elif name == "tone" and self.tones_emb is not None:
                 embedding = self.tones_emb( input )
             elif name == "resp":
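At the call site the only change is the second argument: the prompt codes are still sliced down to the levels below the current quantizer before embedding, and the quant level itself is now forwarded only on version >= 5, where the newer AudioEmbedding accepts it. A sketch of what that slice produces (shapes are illustrative):

import torch

quant_level = 3
prom = torch.randint(0, 1024, (75, 8))  # all 8 RVQ levels of the prompt

# NAR step at level 3: the embedding only sees levels 0..2, which lines up
# with AudioEmbedding summing over range(quant_level)
visible = prom if prom.dim() == 1 or quant_level == 0 else prom[:, :quant_level]
print(visible.shape)  # torch.Size([75, 3])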