God, it would be nice to know the best way to handle audio embeddings, because I genuinely can't tell without skimming through papers or devoting X amount of GPU hours to training
This commit is contained in:
parent
6a11bc9cb6
commit
5120ffdda7
|
@ -190,6 +190,7 @@ class Model:
|
|||
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
|
||||
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
||||
attention: str = "eager" # or flash_attention_2
|
||||
audio_embedding_sums: bool = True
|
||||
|
||||
def get(self, name=None):
    """Return ``[self]`` when no name is given or when it matches ``self.name``, else ``[]``."""
    # A falsy name (None, "") acts as a wildcard and always selects this entry.
    if not name or self.name == name:
        return [self]
    return []
|
||||
|
|
|
@ -174,11 +174,8 @@ class Embedding(nn.Embedding):
|
|||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||
"""
|
||||
|
||||
# Deprecated implementation
|
||||
class MultiEmbedding(nn.Module):
|
||||
"""
|
||||
This embedding sums embeddings on different levels.
|
||||
"""
|
||||
|
||||
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
|
||||
super().__init__()
|
||||
self.monolithic = monolithic
|
||||
|
@ -216,21 +213,41 @@ class MultiEmbedding(nn.Module):
|
|||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(self, l_tokens, token_dim, levels=None):
|
||||
def __init__(
|
||||
self,
|
||||
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
||||
token_dim: int, # dimensionality of the embedding
|
||||
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
||||
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
||||
):
|
||||
super().__init__()
|
||||
# array of embeddings
|
||||
# proms are [0, prom_levels]
|
||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
#
|
||||
self.sums = sums
|
||||
|
||||
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
|
||||
# prom
|
||||
if quant_levels is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
if self.sums:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
else:
|
||||
k = 0 # only use the most significant RVQ bin level for the input prom
|
||||
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
|
||||
# AR resp
|
||||
elif quant_levels is None or quant_levels == 0:
|
||||
x = self.embeddings[0]( xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
if self.sums:
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
else:
|
||||
k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
|
||||
x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
|
||||
|
||||
return x
|
||||
|
||||
|
@ -345,9 +362,17 @@ class Base(nn.Module):
|
|||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
else:
|
||||
# [1024] * 8
|
||||
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
|
||||
self.proms_emb = AudioEmbedding(
|
||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||
levels=self.n_prom_levels if self.version > 3 else None,
|
||||
sums=self.config.audio_embedding_sums
|
||||
)
|
||||
# [1025] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||
levels=self.n_resp_levels if self.version > 3 else None,
|
||||
sums=self.config.audio_embedding_sums
|
||||
)
|
||||
|
||||
|
||||
if self.version >= 3:
|
||||
|
|
Loading…
Reference in New Issue
Block a user