cleaned up the subjugated audio embedding into a flag; the flag can also include the original, underlying embedding (it seems to do better when set to inclusive)
parent ec5eaebcbc
commit 2808f881c8
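To orient the diff below: the old boolean use_external_audio_embeddings becomes a tri-state audio_embeddings_mode, which is threaded through to the AudioEmbedding modules as external_mode. A minimal usage sketch, assuming the repo's global cfg object (the import path and assignment below are illustrative assumptions, not part of this commit):

# hypothetical usage sketch, not part of this commit
from vall_e.config import cfg  # assumed import path for the global config

# old: cfg.model.use_external_audio_embeddings = True
# new: a tri-state mode, forwarded into AudioEmbedding(..., external_mode=...)
cfg.model.audio_embeddings_mode = "inclusive"  # None | "exclusive" | "inclusive"; "inclusive" reportedly does better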
@@ -220,7 +220,7 @@ class Model:
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: str | None = None # for now it sets things to be HF compatible
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
-	use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings
+	audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 
 	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
@@ -263,12 +263,13 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
-# some experimental shit involving using the Encodec/DAC model's embeddings itself
+# an experimental way to include "trained" embeddings from the audio backend itself
+# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
+# each audio backend does its "embeddings" a different way that isn't just an embedding's weights
 @torch.inference_mode()
 def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
 	model = _load_model(device)
-
 
 	codes = codes.to(device=device, dtype=torch.int32)
 
 	if codes.dim() == 1:
@@ -277,7 +278,7 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
 		codes = codes[:, quant_level]
 		codes = rearrange(codes, "t -> 1 t")
 
 	# dac conveniently has its dim = 1024
 	if cfg.audio_backend == "dac":
 		emb = model.quantizer.quantizers[quant_level]
 
@@ -287,6 +288,15 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
 
 		return x
+
+	"""
+	# vocos inconveniently has its dim = 128
+	elif cfg.audio_backend == "vocos":
+		x = model.codes_to_features(codes)
+	# encodec inconveniently has its dim = 300
+	elif cfg.audio_backend == "encodec":
+		...
+	"""
 
 	raise Exception(f'Currently only DAC is supported')
 
 @torch.inference_mode()
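For reference, a hypothetical call pattern for the encode_as_embedding helper above (module path, shapes, and values are assumptions for illustration; per the hunk above, only the DAC backend is currently supported):

# hypothetical usage sketch, not part of this commit
import torch
from vall_e.emb.qnt import encode_as_embedding  # assumed module path

codes = torch.randint(0, 1024, (128, 4))         # (timesteps, RVQ levels) for a DAC-quantized clip
emb = encode_as_embedding(codes, quant_level=1)  # selects level 1; returns roughly (timesteps, backend dim)
# per the comment above, dac conveniently has its dim = 1024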
@@ -107,10 +107,10 @@ class AR_NAR(Base):
 		return True
 
 	@property
-	def use_external_audio_embeddings(self) -> bool:
+	def audio_embeddings_mode(self) -> bool:
 		if hasattr(self, "config") and self.config:
-			return self.config.use_external_audio_embeddings
-		return cfg.model.use_external_audio_embeddings
+			return self.config.audio_embeddings_mode
+		return cfg.model.audio_embeddings_mode
 
 	@property
 	def version(self) -> int:
@@ -473,7 +473,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+	steps = 100
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -149,7 +149,8 @@ class AudioEmbedding(nn.Module):
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
-		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+		external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
 	):
 		super().__init__()
 		# array of embeddings
@@ -157,10 +158,42 @@ class AudioEmbedding(nn.Module):
 		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		#
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
-
-	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+		self.external_mode = external_mode
+
+		# set initial weights to zero
+		if self.external_mode == "inclusive":
+			for i, embedding in enumerate(self.embeddings):
+				embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
+
+	def external_embeddings(self, input: Tensor) -> Tensor:
+		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+
+		# for AR, trim any stop tokens
+		has_stop_token = False
+		if quant_level == 0:
+			stop_token = self.embeddings[0].weight.shape[0] - 1
+			stop_token_indices = (input == stop_token).nonzero()
+			has_stop_token = len(stop_token_indices) > 0
+
+		if has_stop_token:
+			input = input[:stop_token_indices.min().item()]
+
+		# get external embedding
+		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+		# resize if necessary (in case the external embeddings do not match our model dim)
+		embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
+
+		# reintroduce stop token
+		if has_stop_token:
+			stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
+			embedding = torch.concat( [ embedding, stop_token ] )
+
+		return embedding
+
+	def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
 		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
+
 		if self.sums and quant_level > 0:
@@ -171,26 +204,17 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
-# subjugates the audio backend's embeddings
-# inherits for use of the stop token
-class AudioEmbedding_External(AudioEmbedding):
-	def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
-		if not input.shape[0]:
-			return super().forward( input )
+	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+		x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
 
-		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
-		has_stop_token = quant_level == 0 and input[-1] == 1024
+		if self.external_mode and xi.shape[0] > 0:
+			external_embeddings = self.external_embeddings( xi )
+			if self.external_mode == "exclusive":
+				return external_embeddings
+			x += external_embeddings
 
-		if has_stop_token:
-			input = input[:-1]
+		return x
 
-		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
-
-		if has_stop_token:
-			stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
-			embedding = torch.concat( [ embedding, stop_token ] )
-
-		return embedding
 # per-level classification
 class AudioClassifier(nn.Module):
 	def __init__(
@@ -292,8 +316,8 @@ class Base(nn.Module):
 		return False
 
 	@property
-	def use_external_audio_embeddings(self) -> bool:
-		return False
+	def audio_embeddings_mode(self) -> str | None:
+		return None
 
 	@property
 	def version(self) -> int:
@@ -392,23 +416,16 @@ class Base(nn.Module):
 				l_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
-		elif self.use_external_audio_embeddings:
-			self.proms_emb = AudioEmbedding_External(
-				[n_prom_tokens] * self.n_prom_levels, d_model,
-				sums=audio_embedding_sums,
-			)
-			self.resps_emb = AudioEmbedding_External(
-				l_tokens, d_model,
-				sums=audio_embedding_sums,
-			)
 		else:
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
 				sums=audio_embedding_sums,
+				external_mode=self.audio_embeddings_mode,
 			)
 			self.resps_emb = AudioEmbedding(
 				l_tokens, d_model,
 				sums=audio_embedding_sums,
+				external_mode=self.audio_embeddings_mode,
 			)
 
 		# useless since I actually removed using these with the input processing overhaul...
@@ -213,15 +213,16 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
 	return model
 
 # trim/expand a tensor (for example, in a state dict)
-def resize_weight( weight, target ):
+def resize_weight( weight, target, dim=0, random=True ):
 	# trim
-	if target < weight.shape[0]:
+	if target < weight.shape[dim]:
 		return weight[:target]
 	# expand
-	if target > weight.shape[0]:
+	if target > weight.shape[dim]:
+		fn = torch.rand if random else torch.zeros
 		return torch.stack(
 			[ x for x in weight ] +
-			[ torch.rand( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[0] ) ]
+			[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
 		)
 
 	return weight
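Taken together, a standalone sketch of what the "inclusive" path in AudioEmbedding.forward amounts to (names, shapes, and the stand-in tensor are assumptions for illustration): the learned table starts at zero and the audio backend's embedding is added on top, so the trained weights effectively learn a correction over the backend's own representation.

# standalone sketch of the "inclusive" combination, not part of this commit
import torch
import torch.nn as nn

token_dim = 1024
learned = nn.Embedding(1024 + 1, token_dim)                       # 1024 codes + stop token, as in l_tokens
learned.weight = nn.Parameter(torch.zeros(learned.weight.shape))  # zero-init, mirroring external_mode == "inclusive"

codes = torch.randint(0, 1024, (32,))  # a fake AR (level 0) token sequence
external = torch.randn(32, token_dim)  # stand-in for encode_as_embedding(codes, 0), resized to token_dim

x = learned(codes) + external          # "inclusive": learned (initially zero) plus the backend embedding
# "exclusive" would use only external; mode None would use only learned(codes)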