comments for clarity

mrq 2025-03-16 11:30:23 -05:00
parent c5475ebc91
commit 2dfef693c4
2 changed files with 23 additions and 11 deletions


@@ -15,6 +15,7 @@ In theory, RVQ codecs should work better, as "importance" is consolidated in lev
* The underlying model could technically derive this importance itself, as it does receive the entire signal.
* The glamor of `nvidia/audio-codec-44khz` might not be so glamorous, as its codebooks might be too dense for a model to operate on efficiently, and its encoder/decoder is ***slow*** on ROCm.
* in other words, DAC might be preferable as a 44kHz medium.
* hopefully this is simply a problem that can be "worked out" with more training time, just as the "low confidence at higher codebook levels" problem eventually works itself out.
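
A toy sketch of the "importance consolidates in lower levels" intuition (scalar rounding at successively finer steps stands in for learned RVQ codebooks; not how any real codec quantizes): each level quantizes the previous level's residual, so level 0 captures nearly all of the signal's energy and later levels only add small corrections.

```python
import torch

x = torch.randn(100_000)
residual = x.clone()
for level in range(4):
    step = 1.0 / (2 ** level)                 # finer grid at each level
    code = (residual / step).round() * step   # this level's contribution
    residual = residual - code                # what is left for the next level
    print(level, code.pow(2).mean().item(), residual.pow(2).mean().item())
# the per-level energy drops sharply: level 0 carries most of it
```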
## `AudioEncoder` / `AudioDecoder`


@@ -124,43 +124,53 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
return torch.concat( [ i.t() for i in input ] ).t().flatten()
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Embedding that sums each codebook level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
def __init__(
self,
l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other codebook levels (I do not know which way is better)
l_embedding_names: list[str] = [], # names to map to indices
):
super().__init__()
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR (except [-1] for NAR-len if utilized)
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
#
# index of name maps to its corresponding embedding in the list
self.names = l_embedding_names
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
def forward(
self,
xi: Tensor, # input tensor
offset: int | None = None, # explicit offset, interop for the older codebase. use `name` instead
quant_level: int | None = None, # the codebook level of the audio we currently have (our `input_quant_level`)
name: str | None = None, # specifies where in the embeddings list to start from and iterate through
sums = None
) -> Tensor:
# if not explicitly requested, use the default setting at instantiation time
if sums is None:
sums = self.sums
# if not explicitly requested, assume input quant_level based on shape
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
# handle mapping from name
# handle mapping embedding index offset
if name in self.names:
offset = self.names.index( name )
offset -= quant_level # offset by quant level since it'll iterate up that many levels
offset -= quant_level # offset by quant_level since it'll iterate up that many levels
# sum all prior codebook levels if requested (as quant_level = 0 does not have any other codebooks to sum through)
if sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level ) ] )
else:
k = quant_level
x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )
input_quant_level = quant_level
x = self.embeddings[input_quant_level + offset]( xi if xi.dim() == 1 else xi[:, input_quant_level] )
return x
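
For reference, a minimal sketch of the `sums=True` path above (hypothetical token counts and shapes, with a plain `nn.Embedding` standing in for the `ml.Embedding` wrapper): with an acoustic prompt carrying codebook levels 0 and 1 and `quant_level=2`, each prior level is embedded with its own table and the results are summed element-wise.

```python
import torch
import torch.nn as nn

token_dim = 1024
# hypothetical: one embedding table per codebook level, 1024 tokens each
embeddings = nn.ModuleList([nn.Embedding(1024, token_dim) for _ in range(8)])

xi = torch.randint(0, 1024, (75, 2))  # [timesteps, codebook levels 0..1]
quant_level = 2                       # level currently being worked on
offset = 0                            # in the real class: names.index(name) - quant_level

# sums=True branch: embed each prior level with its own table and add the results
x = sum(embeddings[k + offset](xi[:, k]) for k in range(quant_level))
print(x.shape)                        # torch.Size([75, 1024])
```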
@@ -880,11 +890,12 @@ class Base(nn.Module):
quant_level
)
else:
input_quant_level = 0 if quant_level == 0 else quant_level - 1 # input is one below the target quant level
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
#offset = 0 if classifier_level.startswith("AR:") else 1,
name = classifier_level,
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
quant_level = input_quant_level,
)
# apply token dropout
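
A quick trace of the off-by-one above (hypothetical shapes, just evaluating the two expressions): when predicting codebook level 3, `input[:, :quant_level]` keeps the already-decoded levels 0..2 and `input_quant_level` comes out to 2.

```python
import torch

# hypothetical: 75 timesteps, 8 codebook levels already decoded
input = torch.randint(0, 1024, (75, 8))
quant_level = 3                                   # target level being predicted

input_quant_level = 0 if quant_level == 0 else quant_level - 1
sliced = input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level]

print(sliced.shape, input_quant_level)            # torch.Size([75, 3]) 2
```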