experimental method of using DAC's quantizer "embeddings" to see if it helps with model quality

This commit is contained in:
mrq 2024-06-29 19:46:11 -05:00
parent a8718d35a4
commit ec5eaebcbc
4 changed files with 80 additions and 4 deletions

View File

@ -220,6 +220,7 @@ class Model:
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: str | None = None # for now it sets things to be HF compatible
kv_heads: int = 0 # MHA or GQA (for supported backends)
use_external_audio_embeddings: bool = False # repurposes the audio backend's encoding/decoding model to provide the audio embeddings
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
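
A minimal sketch of toggling the new flag (not part of this diff; assumes cfg.model is an instance of the Model dataclass above and that the package exposes its config singleton as vall_e.config.cfg):

# hypothetical sketch, not from this commit: enable the experimental embeddings path
from vall_e.config import cfg  # assumed import path for this repo's config singleton

cfg.model.use_external_audio_embeddings = True  # route prom/resp embeddings through the codec's own quantizer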

View File

@ -182,7 +182,6 @@ def unload_model():
_load_model.cache_clear()
_load_encodec_model.cache_clear() # because vocos can only decode

@torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
    # upcast so it won't whine
@ -264,6 +263,31 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

# some experimental shit involving using the Encodec/DAC model's embeddings itself
@torch.inference_mode()
def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
    model = _load_model(device)
    codes = codes.to(device=device, dtype=torch.int32)

    if codes.dim() == 1:
        codes = rearrange(codes, "t -> 1 t")
    else:
        codes = codes[:, quant_level]
        codes = rearrange(codes, "t -> 1 t")

    if cfg.audio_backend == "dac":
        emb = model.quantizer.quantizers[quant_level]

        x = emb.decode_code(codes)
        x = emb.out_proj(x)
        x = x[0].t().detach()

        return x

    raise Exception('Currently only DAC is supported')
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
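
A rough usage sketch for the new encode_as_embedding helper (not part of this diff; assumes the DAC audio backend is configured and that codes follow the [t, n_levels] layout this file's encode produces):

# hypothetical usage sketch, not in this commit
import torch
from vall_e.emb.qnt import encode_as_embedding  # assumed absolute import path for this repo

codes = torch.randint(0, 1024, (250, 8), dtype=torch.int32)  # fake DAC codes: 250 frames, 8 RVQ levels
emb = encode_as_embedding(codes, quant_level=2)              # look up only level 2's quantizer embedding
print(emb.shape)                                             # expected: (250, codec_dim), one row per frame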

View File

@ -18,7 +18,7 @@ from einops import rearrange
from torch import Tensor
from tqdm import trange
from ..emb.qnt import trim
from ..emb.qnt import trim, encode_as_embedding
from .lora import enable_lora
@ -106,6 +106,12 @@ class AR_NAR(Base):
def monolithic(self) -> bool:
    return True

@property
def use_external_audio_embeddings(self) -> bool:
    if hasattr(self, "config") and self.config:
        return self.config.use_external_audio_embeddings
    return cfg.model.use_external_audio_embeddings

@property
def version(self) -> int:
    if hasattr(self, "config") and self.config:
@ -191,14 +197,23 @@ class AR_NAR(Base):
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

# append stop tokens for AR
# could technically do it in the .inputs call
for i in range(batch_size):
    # cap quant_level if it exceeds its corresponding resp/prom
    if quant_levels[i] >= resps_list[i].shape[-1]:
        quant_levels[i] = resps_list[i].shape[-1] - 1

    if quant_levels[i] >= proms_list[i].shape[-1]:
        quant_levels[i] = proms_list[i].shape[-1] - 1

    # only apply stop token for RVQ level 0
    if quant_levels[i] > 0:
        continue

    # append stop tokens for AR
    # could technically do it in the .inputs call
    resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])

inputs = self.inputs(
    text_list=text_list,
    proms_list=proms_list,

View File

@ -31,6 +31,8 @@ from .arch import *
from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
from ..emb.qnt import encode_as_embedding

def _create_mask(l, device):
    """1 is valid region and 0 is invalid."""
    seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -169,6 +171,26 @@ class AudioEmbedding(nn.Module):
return x
# repurposes the audio backend's own embeddings
# inherits from AudioEmbedding so the learned stop token embedding can still be used
class AudioEmbedding_External(AudioEmbedding):
    def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
        if not input.shape[0]:
            return super().forward( input )

        quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
        has_stop_token = quant_level == 0 and input[-1] == 1024

        if has_stop_token:
            input = input[:-1]

        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)

        if has_stop_token:
            stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
            embedding = torch.concat( [ embedding, stop_token ] )

        return embedding
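
A small shape sketch for the external embedding path (illustrative only, not in this commit; resps_emb stands in for an AudioEmbedding_External instance like the one wired up in Base.__init__ below, and 1024 is assumed to be the AR stop token as in the check above):

# hypothetical sketch: an RVQ level 0 response sequence that ends in the stop token
import torch

resp = torch.cat([
    torch.randint(0, 1024, (120,), dtype=torch.int16),  # 120 frames of RVQ level 0 codes
    torch.tensor([1024], dtype=torch.int16),             # trailing stop token
])
emb = resps_emb(resp)       # resps_emb: an AudioEmbedding_External instance
assert emb.shape[0] == 121  # 120 rows from the DAC quantizer + 1 learned stop token row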
# per-level classification
class AudioClassifier(nn.Module):
    def __init__(
@ -269,6 +291,10 @@ class Base(nn.Module):
def monolithic(self) -> bool:
    return False

@property
def use_external_audio_embeddings(self) -> bool:
    return False

@property
def version(self) -> int:
    return 1
@ -366,6 +392,15 @@ class Base(nn.Module):
        l_tokens, d_model,
        levels=self.n_resp_levels if self.version > 3 else None,
    )
elif self.use_external_audio_embeddings:
    self.proms_emb = AudioEmbedding_External(
        [n_prom_tokens] * self.n_prom_levels, d_model,
        sums=audio_embedding_sums,
    )
    self.resps_emb = AudioEmbedding_External(
        l_tokens, d_model,
        sums=audio_embedding_sums,
    )
else:
    self.proms_emb = AudioEmbedding(
        [n_prom_tokens] * self.n_prom_levels, d_model,
@ -835,6 +870,7 @@ class Base(nn.Module):
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None
# is already an embedding
if name == "task":
    # noop
    # *maybe* inject a token for specifying task type