This commit is contained in:
mrq 2024-06-06 19:41:26 -05:00
parent ee25d2e62e
commit 516b0894d7

View File

@ -105,9 +105,7 @@ class AudioEmbedding_Old(nn.Module):
self, self,
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token) l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding token_dim: int, # dimensionality of the embedding
mode: "old", # old | prom | resp
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics) levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
): ):
super().__init__() super().__init__()
# array of embeddings # array of embeddings
@ -115,30 +113,18 @@ class AudioEmbedding_Old(nn.Module):
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this) # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
#
self.mode = mode
#
self.sums = sums
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor: def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
# prom # prom
if quant_levels is None and xi.shape[-1] > 1: if quant_levels is None and xi.shape[-1] > 1:
if self.sums: x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = 0 # only use the most significant RVQ bin level for the input prom
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
# AR resp # AR resp
elif quant_levels is None or quant_levels == 0: elif quant_levels is None or quant_levels == 0:
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] ) x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
# NAR resp # NAR resp
else: else:
if self.sums: x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
return x return x
@ -165,16 +151,16 @@ class AudioEmbedding(nn.Module):
if quant_level is None: if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1 quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
# embeddings for AR/NAR cannot be shared # jank but needed
offset = 0 if self.mode == "prom" or quant_level == 0 else 1
if xi.dim() == 1: if xi.dim() == 1:
x = self.embeddings[quant_level]( xi ) return self.embeddings[quant_level]( xi )
elif self.sums and quant_level > 0:
offset = 0 if self.mode == "prom" else 1
if self.sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] ) x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
else: else:
k = quant_level k = quant_level - 1
x = self.embeddings[k + offset]( xi[:, k] ) x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )
return x return x
@ -302,15 +288,11 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding_Old( self.proms_emb = AudioEmbedding_Old(
[n_prom_tokens] * self.n_prom_levels, d_model, [n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None, levels=self.n_prom_levels if self.version > 3 else None,
mode="prom" if self.version >= 5 else "old",
sums=self.config.audio_embedding_sums if self.config is not None else True,
) )
# [1024 + STOP] + [1024] * 8 # [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding_Old( self.resps_emb = AudioEmbedding_Old(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None, levels=self.n_resp_levels if self.version > 3 else None,
mode="resp" if self.version >= 5 else "old",
sums=self.config.audio_embedding_sums if self.config is not None else True
) )
else: else:
self.proms_emb = AudioEmbedding( self.proms_emb = AudioEmbedding(
@ -686,7 +668,7 @@ class Base(nn.Module):
elif name == "lang" and self.langs_emb is not None: elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input ) embedding = self.langs_emb( input )
elif name == "prom": elif name == "prom":
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] ) embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
elif name == "tone" and self.tones_emb is not None: elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input ) embedding = self.tones_emb( input )
elif name == "resp": elif name == "resp":