commit 2dfef693c4 (parent c5475ebc91)

    comments for clarity
@@ -15,6 +15,7 @@ In theory, RVQ codecs should work better, as "importance" is consolidated in lev
 * The underlying model could technically derive this importance itself, as it does receive the entire signal.
 * The glamor of `nvidia/audio-codec-44khz` might not be so glamorous as the codebooks might be too dense for a model to easily operate on efficiently, as well as the codec's encoder/decoder being ***slow*** on ROCm.
 	* in other words, DAC might be preferable as a 44KHz medium.
+	* this might simply be a problem that can be "worked out" with more training time, hopefully, just as the "low confidence of higher codebook level" problem eventually works itself out.
 
 ## `AudioEncoder` / `AudioDecoder`
 
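A toy sketch of why the lower levels of an RVQ codec carry the most "importance": each level only quantizes the residual left behind by the levels before it, so the earliest codebooks absorb the bulk of the signal's energy. The codebooks below are random and hand-scaled (a real codec trains them, so the effect is far stronger there); this only illustrates the mechanism:

```python
import torch

torch.manual_seed(0)

dim, codebook_size, n_levels = 32, 256, 8
# random toy codebooks, scaled down per level to roughly track the shrinking
# residual; a trained codec learns this adaptation instead
codebooks = [torch.randn(codebook_size, dim) * 0.6 * 0.8 ** level for level in range(n_levels)]

def rvq_encode(x: torch.Tensor) -> list[int]:
	# greedily pick the nearest entry per level, then pass the leftover down
	codes, residual = [], x.clone()
	for cb in codebooks:
		idx = torch.cdist(residual[None], cb).argmin().item()
		codes.append(idx)
		residual = residual - cb[idx]
	return codes

def rvq_decode(codes: list[int], levels: int) -> torch.Tensor:
	# reconstruct from only the first `levels` codebook levels
	return sum(codebooks[k][codes[k]] for k in range(levels))

x = torch.randn(dim)
codes = rvq_encode(x)
for levels in range(1, n_levels + 1):
	err = (x - rvq_decode(codes, levels)).norm() / x.norm()
	print(f"levels used: {levels}, relative error: {err:.3f}")
# the first levels remove the bulk of the error; later ones only refine it
```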
@@ -124,43 +124,53 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
 def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 	return torch.concat( [ i.t() for i in input ] ).t().flatten()
 
-# Embedding that sums each RVQ-bin level within a given input acoustic prompt
+# Embedding that sums each codebook level within a given input acoustic prompt
 # Mostly to handle some oversights and errors during testing
 class AudioEmbedding(nn.Module):
 	def __init__(
 		self,
 		l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
-		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other codebook levels (I do not know which way is better)
 		l_embedding_names: list[str] = [], # names to map to indices
 	):
 		super().__init__()
 		# array of embeddings
 		# proms are [0, resp_levels]
-		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
+		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR (except [-1] for NAR-len if utilized)
 		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
-		#
+		# index of name maps to its corresponding embedding in the list
 		self.names = l_embedding_names
 
-	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
+	def forward(
+		self,
+		xi: Tensor, # input tensor
+		offset: int | None = None, # explicit offset, interop for the older codebase. use `name` instead
+		quant_level: int | None = None, # the codebook level of the audio we currently have (our `input_quant_level`)
+		name: str | None = None, # specifies where in the embeddings list to start from and iterate through
+		sums = None
+	) -> Tensor:
+		# if not explicitly requested, use the default setting at instantiation time
 		if sums is None:
 			sums = self.sums
 
+		# if not explicitly requested, assume input quant_level based on shape
 		if quant_level is None:
 			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
-		# handle mapping from name
+		# handle mapping embedding index offset
 		if name in self.names:
 			offset = self.names.index( name )
-			offset -= quant_level # offset by quant level since it'll iterate up that many levels
+			offset -= quant_level # offset by quant_level since it'll iterate up that many levels
 
+		# sum all prior codebook levels if requested (as quant_level = 0 does not have any other codebooks to sum through)
 		if sums and quant_level > 0:
-			x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
+			x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level ) ] )
 		else:
-			k = quant_level
-			x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )
+			input_quant_level = quant_level
+			x = self.embeddings[input_quant_level + offset]( xi if xi.dim() == 1 else xi[:, input_quant_level] )
 
 		return x
 
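As a side note on the context above, `_interleave_sequence_flatten` appears to expect each list entry to be a `(timesteps, levels)` tensor sharing the same number of timesteps (an assumption read off the transpose/flatten dance, not stated in the source): the output walks frame by frame, emitting every entry's codes for one frame before moving to the next.

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])  # two frames, two codebook levels
b = torch.tensor([[5], [6]])        # two frames, one codebook level

# same expression as the function body above
interleaved = torch.concat([i.t() for i in [a, b]]).t().flatten()
print(interleaved)  # tensor([1, 2, 5, 3, 4, 6])
```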
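To make the `sums` path concrete, here is a standalone sketch of the summing lookup, substituting plain `nn.Embedding` for the repo's `ml.Embedding` wrapper and making up the sizes: for a NAR step targeting `quant_level = 3`, the input carries codebook levels 0..2, and each level is looked up in its own table and summed into one embedding per frame.

```python
import torch
import torch.nn as nn

token_dim, n_tokens, n_levels = 8, 1024, 4
embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

xi = torch.randint(0, n_tokens, (75, 3))  # 75 frames, input codebook levels 0..2
quant_level, offset = 3, 0

# the `sums` branch: one table per level, summed into a single embedding per frame
x = sum(embeddings[k + offset](xi[:, k]) for k in range(quant_level))
print(x.shape)  # torch.Size([75, 8])
```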
@@ -880,11 +890,12 @@ class Base(nn.Module):
 					quant_level
 				)
 			else:
+				input_quant_level = 0 if quant_level == 0 else quant_level - 1 # input is one below the target quant level
 				embedding = self.resps_emb(
 					input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
 					#offset = 0 if classifier_level.startswith("AR:") else 1,
 					name = classifier_level,
-					quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
+					quant_level = input_quant_level,
 				)
 
 		# apply token dropout
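To spell out the mapping the new `input_quant_level` variable captures (the loop below is illustrative only; names follow the diff): the AR (`quant_level == 0`) consumes the level-0 codes themselves, while every NAR step consumes the levels below its target, so its input sits one codebook level down.

```python
for quant_level in range(4):
	input_quant_level = 0 if quant_level == 0 else quant_level - 1
	print(f"target codebook level {quant_level} <- input at level {input_quant_level}")
# target codebook level 0 <- input at level 0
# target codebook level 1 <- input at level 0
# target codebook level 2 <- input at level 1
# target codebook level 3 <- input at level 2
```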