god it would be nice to know the best way to handle audio embeddings, because I genuinely don't know without skimming through papers or devoting X GPU hours to training

mrq 2024-04-29 18:24:05 -05:00
parent 6a11bc9cb6
commit 5120ffdda7
2 changed files with 35 additions and 9 deletions


@@ -190,6 +190,7 @@ class Model:
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "eager" # or flash_attention_2
audio_embedding_sums: bool = True
def get(self, name=None):
return [ self ] if not name or self.name == name else []
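As a usage sketch (not part of the commit; a trimmed-down stand-in for the Model dataclass above, just to illustrate the new knob), toggling the flag off would look something like this:

from dataclasses import dataclass, field

# trimmed-down stand-in for the Model dataclass above (illustrative only)
@dataclass
class Model:
    attention: str = "eager"  # or flash_attention_2
    audio_embedding_sums: bool = True  # the new option: sum embeddings across RVQ bin levels
    frozen_params: list[str] = field(default_factory=lambda: [])

model_cfg = Model(audio_embedding_sums=False)  # opt out of summing, use a single level per embedding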


@@ -174,11 +174,8 @@ class Embedding(nn.Embedding):
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""
# Deprecated implementation
class MultiEmbedding(nn.Module):
"""
This embedding sums embeddings on different levels.
"""
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
super().__init__()
self.monolithic = monolithic
@@ -216,21 +213,41 @@ class MultiEmbedding(nn.Module):
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(self, l_tokens, token_dim, levels=None):
def __init__(
self,
l_tokens: list[int], # number of tokens per level (a list, since the AR resp level also includes a stop token)
token_dim: int, # dimensionality of the embedding
levels: int | None = None, # number of RVQ bin levels (used to size the per-level weights below)
sums: bool = True # whether to sum the embeddings of previous RVQ bin levels to factor them in (it is not yet clear which way is better)
):
super().__init__()
# list of per-level embeddings
# proms cover levels [0, prom_levels)
# resps are split so that [0] is for the AR and [1:] are reserved for the NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# per-level weight scaling each embedding's influence (this should really be unnecessary, since the weights in the embeddings themselves should already account for it)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
#
self.sums = sums
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
if self.sums:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = 0 # only use the most significant RVQ bin level for the input prom
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
# AR resp
elif quant_levels is None or quant_levels == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
if self.sums:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
return x
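For a concrete sense of what the new `sums` toggle changes on the prom path, here is a minimal standalone sketch with toy dimensions (not taken from the repo):

import torch
import torch.nn as nn

n_tokens, token_dim, n_levels = 1024, 8, 4  # toy sizes
embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

# a fake acoustic prompt: 6 timesteps, n_levels RVQ bins per timestep
xi = torch.randint(0, n_tokens, (6, n_levels))

# sums=True: every RVQ bin level contributes to each timestep's embedding
x_summed = sum(embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))

# sums=False: only the most significant level (k=0) is embedded
x_single = embeddings[0](xi[:, 0])

print(x_summed.shape, x_single.shape)  # both torch.Size([6, 8])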
@@ -345,9 +362,17 @@ class Base(nn.Module):
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
# [1024] * 8
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
sums=self.config.audio_embedding_sums
)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
sums=self.config.audio_embedding_sums
)
if self.version >= 3:
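For context, a small sketch of the token-count lists passed as `l_tokens` above, assuming the 8 levels and vocabulary sizes implied by the `# [1024] * 8` and `# [1025] + [1024] * 8` comments:

n_prom_tokens, n_resp_tokens, n_levels = 1024, 1025, 8  # assumed values, per the comments above

# proms: every RVQ bin level shares the same vocabulary size
prom_l_tokens = [n_prom_tokens] * n_levels  # [1024, 1024, ..., 1024]

# resps: only the AR level (index 0) carries the extra stop token
resp_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_levels - 1)  # [1025, 1024, ..., 1024]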