cleaned up the subjugated audio embedding into a flag; the flag can also have it include the original, underlying embedding as well (it seems to do better when set to inclusive)

This commit is contained in:
mrq 2024-06-29 21:46:35 -05:00
parent ec5eaebcbc
commit 2808f881c8
5 changed files with 71 additions and 43 deletions
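
For orientation before the per-file diffs, here is a minimal sketch of what the three values of the new flag mean, assuming "internal" stands for the model's own learned embedding and "external" for the embedding fetched from the audio backend (both names are illustrative, not from the diff):

# minimal sketch of the flag's semantics, mirroring the AudioEmbedding.forward() hunk below
def combine(internal, external, mode=None):
	if mode == "exclusive":   # use only the audio backend's embedding
		return external
	if mode == "inclusive":   # sum both; the learned table starts at zero and learns a residual
		return internal + external
	return internal           # None: previous behavior, learned embeddings only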

View File

@@ -220,7 +220,7 @@ class Model:
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: str | None = None # for now it sets things to be HF compatible
kv_heads: int = 0 # MHA or GQA (for supported backends)
use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
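
For context, the new field is read through cfg.model just like its predecessor (see the AR_NAR property hunk below); a hedged usage sketch, with the import path assumed from the repository layout:

from vall_e.config import cfg                   # assumed import path

cfg.model.audio_embeddings_mode = "inclusive"   # None | "exclusive" | "inclusive"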

View File

@@ -263,12 +263,13 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
# some experimental shit involving using the Encodec/DAC model's embeddings itself
# an experimental way to include "trained" embeddings from the audio backend itself
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
# each audio backend does its "embeddings" in a different way that isn't just a table of embedding weights
@torch.inference_mode()
def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
model = _load_model(device)
codes = codes.to(device=device, dtype=torch.int32)
if codes.dim() == 1:
@@ -277,7 +278,7 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
codes = codes[:, quant_level]
codes = rearrange(codes, "t -> 1 t")
# dac conveniently has its dim = 1024
if cfg.audio_backend == "dac":
emb = model.quantizer.quantizers[quant_level]
@@ -287,6 +288,15 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
return x
"""
# vocos inconveniently has its dim = 128
elif cfg.audio_backend == "vocos":
x = model.codes_to_features(codes)
# encodec inconveniently has its dim = 300
elif cfg.audio_backend == "encodec":
...
"""
raise Exception(f'Currently only DAC is supported')
@torch.inference_mode()
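
The lines elided between the two DAC hunks above presumably decode the codes through the per-level quantizer before returning. A hedged sketch of that path, assuming descript-audio-codec's VectorQuantize API (decode_code / out_proj) and codes already shaped [1, t]:

import torch

@torch.inference_mode()
def dac_codes_to_embedding(model, codes: torch.Tensor, quant_level: int = 0) -> torch.Tensor:
	vq = model.quantizer.quantizers[quant_level]   # per-level VectorQuantize
	x = vq.decode_code(codes)                      # [1, codebook_dim, t]: codebook lookup
	x = vq.out_proj(x)                             # [1, 1024, t]: project back up to DAC's latent dim
	return x[0].t().detach()                       # [t, 1024], detached so the frozen backend is not trained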

View File

@@ -107,10 +107,10 @@ class AR_NAR(Base):
return True
@property
def use_external_audio_embeddings(self) -> bool:
def audio_embeddings_mode(self) -> str | None:
if hasattr(self, "config") and self.config:
return self.config.use_external_audio_embeddings
return cfg.model.use_external_audio_embeddings
return self.config.audio_embeddings_mode
return cfg.model.audio_embeddings_mode
@property
def version(self) -> int:
@@ -473,7 +473,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
steps = 100
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -149,7 +149,8 @@ class AudioEmbedding(nn.Module):
self,
l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
):
super().__init__()
# array of embeddings
@@ -157,10 +158,42 @@ class AudioEmbedding(nn.Module):
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
#
# further experimentation is needed to see if this actually is useful
self.sums = sums
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
self.external_mode = external_mode
# set initial weights to zero
if self.external_mode == "inclusive":
for i, embedding in enumerate(self.embeddings):
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
def external_embeddings(self, input: Tensor) -> Tensor:
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
# for AR, trim any stop tokens
has_stop_token = False
if quant_level == 0:
stop_token = self.embeddings[0].weight.shape[0] - 1
stop_token_indices = (input == stop_token).nonzero()
has_stop_token = len(stop_token_indices) > 0
if has_stop_token:
input = input[:stop_token_indices.min().item()]
# get external embedding
embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
# resize if necessary (in case the external embeddings do not match our model dim)
embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
# reintroduce stop token
if has_stop_token:
stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
embedding = torch.concat( [ embedding, stop_token ] )
return embedding
def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
if self.sums and quant_level > 0:
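
A short runnable illustration of why the learned tables are zeroed for "inclusive" mode: at initialization the summed embedding equals the backend's embedding exactly, so the nn.Embedding only has to learn a residual on top of it (sizes below are illustrative):

import torch
import torch.nn as nn

table = nn.Embedding(1025, 1024)                       # illustrative: 1024 codes + stop token
table.weight = nn.Parameter(torch.zeros(table.weight.shape))

codes = torch.tensor([12, 55, 7])
external = torch.randn(3, 1024)                        # stands in for encode_as_embedding(...)
assert torch.equal(table(codes) + external, external)  # the sum starts out as exactly the external embedding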
@@ -171,26 +204,17 @@ class AudioEmbedding(nn.Module):
return x
# subjugates the audio backend's embeddings
# inherits for use of the stop token
class AudioEmbedding_External(AudioEmbedding):
def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
if not input.shape[0]:
return super().forward( input )
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
has_stop_token = quant_level == 0 and input[-1] == 1024
if self.external_mode and xi.shape[0] > 0:
external_embeddings = self.external_embeddings( xi )
if self.external_mode == "exclusive":
return external_embeddings
x += external_embeddings
if has_stop_token:
input = input[:-1]
return x
embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
if has_stop_token:
stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
embedding = torch.concat( [ embedding, stop_token ] )
return embedding
# per-level classification
class AudioClassifier(nn.Module):
def __init__(
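
One detail worth spelling out from the hunks above: the quantizer level is inferred purely from the input's shape, which is how the AR and NAR paths share the same embedding module. A small illustrative sketch:

import torch

ar_input  = torch.randint(0, 1024, (50,))     # AR: 1-D [t]          -> quant_level 0
nar_input = torch.randint(0, 1024, (50, 3))   # NAR: 2-D [t, levels] -> quant_level 2

for x in (ar_input, nar_input):
	quant_level = 0 if x.dim() == 1 else x.shape[-1] - 1
	print(tuple(x.shape), quant_level)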
@@ -292,8 +316,8 @@ class Base(nn.Module):
return False
@property
def use_external_audio_embeddings(self) -> bool:
return False
def audio_embeddings_mode(self) -> str | None:
return None
@property
def version(self) -> int:
@@ -392,23 +416,16 @@ class Base(nn.Module):
l_tokens, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
elif self.use_external_audio_embeddings:
self.proms_emb = AudioEmbedding_External(
[n_prom_tokens] * self.n_prom_levels, d_model,
sums=audio_embedding_sums,
)
self.resps_emb = AudioEmbedding_External(
l_tokens, d_model,
sums=audio_embedding_sums,
)
else:
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
sums=audio_embedding_sums,
external_mode=self.audio_embeddings_mode,
)
self.resps_emb = AudioEmbedding(
l_tokens, d_model,
sums=audio_embedding_sums,
external_mode=self.audio_embeddings_mode,
)
# useless since I actually removed using these with the input processing overhaul...

View File

@@ -213,15 +213,16 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target ):
def resize_weight( weight, target, dim=0, random=True ):
# trim
if target < weight.shape[0]:
if target < weight.shape[dim]:
return weight[:target]
# expand
if target > weight.shape[0]:
if target > weight.shape[dim]:
fn = torch.rand if random else torch.zeros
return torch.stack(
[ x for x in weight ] +
[ torch.rand( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[0] ) ]
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
)
return weight
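
Note that the trim and the stack above still act along dim 0 even when another dim is passed, while the new dim=-1 call site in external_embeddings() wants the last dimension resized. A hedged, dim-general sketch of the same helper (the name is illustrative, not the repo's code):

import torch

def resize_weight_sketch(weight: torch.Tensor, target: int, dim: int = 0, random: bool = True) -> torch.Tensor:
	dim = dim % weight.dim()                   # normalize negative dims
	if target < weight.shape[dim]:             # trim along dim
		return weight.narrow(dim, 0, target)
	if target > weight.shape[dim]:             # pad along dim with random or zero values
		fn = torch.rand if random else torch.zeros
		pad_shape = list(weight.shape)
		pad_shape[dim] = target - weight.shape[dim]
		pad = fn(pad_shape).to(device=weight.device, dtype=weight.dtype)
		return torch.cat([weight, pad], dim=dim)
	return weight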