comments for clarity
parent c5475ebc91
commit 2dfef693c4
@@ -15,6 +15,7 @@ In theory, RVQ codecs should work better, as "importance" is consolidated in lev
* The underlying model could technically derive this importance itself, as it does receive the entire signal.
* The glamor of `nvidia/audio-codec-44khz` might not be so glamorous after all: its codebooks might be too dense for a model to operate on efficiently, and the codec's encoder/decoder is ***slow*** on ROCm.
* in other words, DAC might be preferable as a 44KHz medium.
* hopefully, this is simply a problem that can be "worked out" with more training time, just as the "low confidence of higher codebook levels" problem eventually works itself out.

## `AudioEncoder` / `AudioDecoder`

@@ -124,43 +124,53 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
	return torch.concat( [ i.t() for i in input ] ).t().flatten()
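
As a quick sanity check of what the flatten variant produces, here is a small hypothetical example (the `a`/`b` tensors are made up, assuming 2-D inputs of shape `[timesteps, dim]`); the frames of each sequence end up interleaved per timestep:

```python
import torch

# two hypothetical per-stream token sequences, shape [timesteps, 1]
a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])

out = torch.concat( [ i.t() for i in [a, b] ] ).t().flatten()
print(out)  # tensor([1, 4, 2, 5, 3, 6]) -- timesteps from each sequence interleaved
```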

# Embedding that sums each codebook level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
	def __init__(
		self,
		l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
		token_dim: int, # dimensionality of the embedding
		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other codebook levels (I do not know which way is better)
		l_embedding_names: list[str] = [], # names to map to indices
	):
		super().__init__()
		# array of embeddings
		# proms are [0, resp_levels]
		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR (except [-1] for NAR-len if utilized)
		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
		# further experimentation is needed to see if this actually is useful
		self.sums = sums
		# index of name maps to its corresponding embedding in the list
		self.names = l_embedding_names

	def forward(
		self,
		xi: Tensor, # input tensor
		offset: int | None = None, # explicit offset, interop for the older codebase. use `name` instead
		quant_level: int | None = None, # the codebook level of the audio we currently have (our `input_quant_level`)
		name: str | None = None, # specifies where in the embeddings list to start from and iterate through
		sums = None
	) -> Tensor:
		# if not explicitly requested, use the default setting at instantiation time
		if sums is None:
			sums = self.sums

		# if not explicitly requested, assume input quant_level based on shape
		if quant_level is None:
			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1

		# handle mapping embedding index offset
		if name in self.names:
			offset = self.names.index( name )
			offset -= quant_level # offset by quant_level since it'll iterate up that many levels

		# sum all prior codebook levels if requested (as quant_level = 0 does not have any other codebooks to sum through)
		if sums and quant_level > 0:
			x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level ) ] )
		else:
			input_quant_level = quant_level
			x = self.embeddings[input_quant_level + offset]( xi if xi.dim() == 1 else xi[:, input_quant_level] )

		return x
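
For intuition, here is a simplified, hypothetical sketch of the `sums=True` idea, using plain `torch.nn.Embedding` in place of `ml.Embedding` and glossing over the exact `name`/`offset`/`quant_level` bookkeeping: one embedding table per codebook level, with the per-level embeddings summed into a single `[timesteps, token_dim]` tensor:

```python
import torch
import torch.nn as nn

n_tokens, token_dim, n_levels = 1024, 512, 8  # made-up sizes
embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

xi = torch.randint(0, n_tokens, (50, 4))  # a fake prompt: 50 frames, codebook levels 0..3

# sum one embedding per provided codebook level -- the gist of the sums=True path
x = sum(embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))
print(x.shape)  # torch.Size([50, 512])
```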
@@ -880,11 +890,12 @@ class Base(nn.Module):

					quant_level
				)
			else:
				input_quant_level = 0 if quant_level == 0 else quant_level - 1 # input is one below the target quant level
				embedding = self.resps_emb(
					input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
					#offset = 0 if classifier_level.startswith("AR:") else 1,
					name = classifier_level,
					quant_level = input_quant_level,
				)

			# apply token dropout
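
To make the off-by-one relationship concrete, here is a hypothetical sketch (shapes and sizes made up) of how the NAR input lines up with `input_quant_level`: the response codes below the target level are what get embedded, and their shape implies a quant level one below the target:

```python
import torch

resps = torch.randint(0, 1024, (50, 8))  # fake response codes: 50 frames, 8 codebook levels

for quant_level in range(1, 8):  # NAR target levels 1..7
	input_quant_level = quant_level - 1  # input is one below the target quant level
	nar_input = resps[:, :quant_level]   # codebook levels 0 .. quant_level-1
	assert nar_input.shape[-1] - 1 == input_quant_level
```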
|
|