Cleaned up the subjugated audio embedding into a config flag (audio_embeddings_mode); the flag can also include the original, underlying embedding (it seems to do better when set to "inclusive").

mrq 2024-06-29 21:46:35 -05:00
parent ec5eaebcbc
commit 2808f881c8
5 changed files with 71 additions and 43 deletions

View File

@@ -220,7 +220,7 @@ class Model:
     capabilities: list = field(default_factory=lambda: ["ar", "nar"])
     experimental: str | None = None # for now it sets things to be HF compatible
     kv_heads: int = 0 # MHA or GQA (for supported backends)
-    use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings
+    audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
     p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
     rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
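
For reference, a minimal sketch of the three values the new flag accepts; the vall_e.config import path and the cfg.model access are assumptions based on how the flag is read later in this diff, not part of the commit:

# sketch of the three accepted values for audio_embeddings_mode
from vall_e.config import cfg  # assumed import path

cfg.model.audio_embeddings_mode = None         # default: only the model's learned nn.Embedding weights are used
cfg.model.audio_embeddings_mode = "exclusive"  # use only the audio backend's (DAC) quantizer embeddings
cfg.model.audio_embeddings_mode = "inclusive"  # zero-init the learned embeddings and sum them with the backend's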

View File

@@ -263,12 +263,13 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

-# some experimental shit involving using the Encodec/DAC model's embeddings itself
+# an experimental way to include "trained" embeddings from the audio backend itself
+# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
+# each audio backend does their "embeddings" a different way that isn't just embedding weights
 @torch.inference_mode()
 def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
     model = _load_model(device)
     codes = codes.to(device=device, dtype=torch.int32)
     if codes.dim() == 1:
@@ -277,7 +278,7 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
     codes = codes[:, quant_level]
     codes = rearrange(codes, "t -> 1 t")

+    # dac conveniently has its dim = 1024
     if cfg.audio_backend == "dac":
         emb = model.quantizer.quantizers[quant_level]
@@ -287,6 +288,15 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
         return x
+    """
+    # vocos inconveniently has its dim = 128
+    elif cfg.audio_backend == "vocos":
+        x = model.codes_to_features(codes)
+    # encodec inconveniently has its dim = 300
+    elif cfg.audio_backend == "encodec":
+        ...
+    """
     raise Exception(f'Currently only DAC is supported')

 @torch.inference_mode()
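
A rough usage sketch of the patched helper with the DAC backend active; the vall_e.emb.qnt module path and the fake code values below are assumptions for illustration, not part of the commit:

# hypothetical call into encode_as_embedding; codes is a [timesteps, rvq_levels]
# tensor of quantized audio tokens from the DAC backend
import torch
from vall_e.emb.qnt import encode_as_embedding  # assumed module path

codes = torch.randint(0, 1024, (100, 8))         # fake codes: 100 frames, 8 RVQ levels
emb = encode_as_embedding(codes, quant_level=0)  # per-frame embeddings pulled from DAC's level-0 quantizer
# any backend other than "dac" currently raises, per the guard above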

View File

@@ -107,10 +107,10 @@ class AR_NAR(Base):
         return True

     @property
-    def use_external_audio_embeddings(self) -> bool:
+    def audio_embeddings_mode(self) -> str | None:
         if hasattr(self, "config") and self.config:
-            return self.config.use_external_audio_embeddings
-        return cfg.model.use_external_audio_embeddings
+            return self.config.audio_embeddings_mode
+        return cfg.model.audio_embeddings_mode

     @property
     def version(self) -> int:
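
A standalone sketch of the "per-model config with global fallback" pattern used by the property above; DummyConfig, DummyModel, and GLOBAL_MODE are made-up stand-ins for Model and cfg.model, purely for illustration:

from dataclasses import dataclass

GLOBAL_MODE: str | None = None  # stands in for cfg.model.audio_embeddings_mode

@dataclass
class DummyConfig:
    audio_embeddings_mode: str | None = "inclusive"

class DummyModel:
    def __init__(self, config: DummyConfig | None = None):
        self.config = config

    @property
    def audio_embeddings_mode(self) -> str | None:
        # prefer the model's own config, otherwise fall back to the global default
        if hasattr(self, "config") and self.config:
            return self.config.audio_embeddings_mode
        return GLOBAL_MODE

print(DummyModel(DummyConfig()).audio_embeddings_mode)  # -> "inclusive"
print(DummyModel().audio_embeddings_mode)               # -> None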
@@ -473,7 +473,7 @@ def example_usage():
     """

     model = AR_NAR(**kwargs).to(device)
-    steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+    steps = 100

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -149,7 +149,8 @@ class AudioEmbedding(nn.Module):
         self,
         l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
         token_dim: int, # dimensionality of the embedding
-        sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+        sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+        external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
     ):
         super().__init__()
         # array of embeddings
@@ -157,10 +158,42 @@ class AudioEmbedding(nn.Module):
         # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
         # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
-        #
+        # further experimentation is needed to see if this actually is useful
         self.sums = sums
+        self.external_mode = external_mode
+
+        # set initial weights to zero
+        if self.external_mode == "inclusive":
+            for i, embedding in enumerate(self.embeddings):
+                embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
+
+    def external_embeddings(self, input: Tensor) -> Tensor:
+        quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+        # for AR, trim any stop tokens
+        has_stop_token = False
+        if quant_level == 0:
+            stop_token = self.embeddings[0].weight.shape[0] - 1
+            stop_token_indices = (input == stop_token).nonzero()
+            has_stop_token = len(stop_token_indices) > 0
+        if has_stop_token:
+            input = input[:stop_token_indices.min().item()]
+        # get external embedding
+        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+        # resize if necessary (in case the external embeddings do not match our model dim)
+        embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
+        # reintroduce stop token
+        if has_stop_token:
+            stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
+            embedding = torch.concat( [ embedding, stop_token ] )
+        return embedding
+
-    def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+    def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
         quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
         if self.sums and quant_level > 0:
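
A standalone sketch of the stop-token handling in external_embeddings: the audio backend has no codebook entry for the stop token, so it is trimmed before the external lookup and its learned embedding is appended back afterwards. The helper and table below are made up for illustration and stand in for encode_as_embedding plus the learned stop-token embedding:

import torch

def with_stop_token(tokens: torch.Tensor, external_table: torch.Tensor,
                    learned_stop: torch.Tensor, stop_token: int) -> torch.Tensor:
    hits = (tokens == stop_token).nonzero()
    if len(hits) > 0:
        tokens = tokens[: hits.min().item()]     # drop the stop token (and anything after it)
    emb = external_table[tokens]                 # stand-in for encode_as_embedding(...)
    if len(hits) > 0:
        emb = torch.concat([emb, learned_stop])  # reintroduce the stop token's embedding
    return emb

table = torch.randn(1024, 16)                    # fake backend codebook
stop = 1024                                      # index one past the codebook, as on the AR level
seq = torch.tensor([5, 9, 700, stop])
out = with_stop_token(seq, table, torch.zeros(1, 16), stop)  # -> [4, 16]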
@@ -171,26 +204,17 @@ class AudioEmbedding(nn.Module):
         return x

-# subjugates the audio backend's embeddings
-# inherits for use of the stop token
-class AudioEmbedding_External(AudioEmbedding):
-    def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
-        if not input.shape[0]:
-            return super().forward( input )
-        quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
-        has_stop_token = quant_level == 0 and input[-1] == 1024
-        if has_stop_token:
-            input = input[:-1]
-        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
-        if has_stop_token:
-            stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
-            embedding = torch.concat( [ embedding, stop_token ] )
-        return embedding
+    def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+        x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
+        if self.external_mode and xi.shape[0] > 0:
+            external_embeddings = self.external_embeddings( xi )
+            if self.external_mode == "exclusive":
+                return external_embeddings
+            x += external_embeddings
+        return x

 # per-level classification
 class AudioClassifier(nn.Module):
     def __init__(
@@ -292,8 +316,8 @@ class Base(nn.Module):
         return False

     @property
-    def use_external_audio_embeddings(self) -> bool:
-        return False
+    def audio_embeddings_mode(self) -> str | None:
+        return None

     @property
     def version(self) -> int:
@@ -392,23 +416,16 @@
                 l_tokens, d_model,
                 levels=self.n_resp_levels if self.version > 3 else None,
             )
-        elif self.use_external_audio_embeddings:
-            self.proms_emb = AudioEmbedding_External(
-                [n_prom_tokens] * self.n_prom_levels, d_model,
-                sums=audio_embedding_sums,
-            )
-            self.resps_emb = AudioEmbedding_External(
-                l_tokens, d_model,
-                sums=audio_embedding_sums,
-            )
         else:
             self.proms_emb = AudioEmbedding(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
                 sums=audio_embedding_sums,
+                external_mode=self.audio_embeddings_mode,
             )
             self.resps_emb = AudioEmbedding(
                 l_tokens, d_model,
                 sums=audio_embedding_sums,
+                external_mode=self.audio_embeddings_mode,
            )

         # useless since I actually removed using these with the input processing overhaul...
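
Putting the pieces together, the inclusive/exclusive behavior can be illustrated with a small standalone module. This is a sketch, not the project's classes: the random external table stands in for the DAC codebook embeddings that the real code pulls through encode_as_embedding, and stop-token handling plus the sums option are omitted:

import torch
import torch.nn as nn

class SketchEmbedding(nn.Module):
    def __init__(self, n_tokens: int, dim: int, external_table: torch.Tensor, mode: str | None = None):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, dim)
        self.register_buffer("external_table", external_table)  # frozen backend codebook (here: random)
        self.mode = mode
        if mode == "inclusive":
            # start the learned embedding at zero so training only has to learn a
            # residual on top of the backend's embeddings
            self.embedding.weight = nn.Parameter(torch.zeros_like(self.embedding.weight))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        external = self.external_table[tokens]   # backend-provided embedding
        if self.mode == "exclusive":
            return external                      # ignore the learned table entirely
        learned = self.embedding(tokens)
        return learned + external if self.mode == "inclusive" else learned

n_tokens, dim = 1024, 16
table = torch.randn(n_tokens, dim)               # stands in for DAC's quantizer embeddings
tokens = torch.randint(0, n_tokens, (8,))
out = SketchEmbedding(n_tokens, dim, table, mode="inclusive")(tokens)  # -> [8, 16]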

View File

@@ -213,15 +213,16 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
     return model

 # trim/expand a tensor (for example, in a state dict)
-def resize_weight( weight, target ):
+def resize_weight( weight, target, dim=0, random=True ):
     # trim
-    if target < weight.shape[0]:
+    if target < weight.shape[dim]:
         return weight[:target]
     # expand
-    if target > weight.shape[0]:
+    if target > weight.shape[dim]:
+        fn = torch.rand if random else torch.zeros
         return torch.stack(
             [ x for x in weight ] +
-            [ torch.rand( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[0] ) ]
+            [ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
         )

     return weight
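
A quick usage sketch of the generalized helper; the vall_e.utils import path is an assumption, the values are made up, and only the dim=0 trim/expand path is shown:

# hypothetical usage of the patched resize_weight
import torch
from vall_e.utils import ml  # assumed import path

weight = torch.randn(4, 8)

trimmed  = ml.resize_weight(weight, 2)                # -> [2, 8], keeps the first two rows
padded   = ml.resize_weight(weight, 6)                # -> [6, 8], new rows filled with torch.rand
padded_0 = ml.resize_weight(weight, 6, random=False)  # -> [6, 8], new rows filled with zeros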