From ec5eaebcbcdb9c22dccb8925d46513fe7efa52c8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 29 Jun 2024 19:46:11 -0500
Subject: [PATCH] experimental method of using DAC's quantizer "embeddings" to see if it helps with model quality

---
 vall_e/config.py        |  1 +
 vall_e/emb/qnt.py       | 26 +++++++++++++++++++++++++-
 vall_e/models/ar_nar.py | 21 ++++++++++++++++++---
 vall_e/models/base.py   | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 80 insertions(+), 4 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index b870280..e1471eb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -220,6 +220,7 @@ class Model:
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: str | None = None # for now it sets things to be HF compatible
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
+	use_external_audio_embeddings: bool = False # co-opts the audio backend's encoding/decoding model for embeddings
 
 	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 77e0375..36a1f1e 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -182,7 +182,6 @@ def unload_model():
 	_load_model.cache_clear()
 	_load_encodec_model.cache_clear() # because vocos can only decode
 
-
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
 	# upcast so it won't whine
@@ -264,6 +263,31 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
+# experimental: use the EnCodec/DAC model's own quantizer embeddings instead of learned ones
+@torch.inference_mode()
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+	model = _load_model(device)
+
+
+	codes = codes.to(device=device, dtype=torch.int32)
+
+	if codes.dim() == 1: # (t) => (1 t)
+		codes = rearrange(codes, "t -> 1 t")
+	else: # (t q) => (1 t), keeping only the requested RVQ level
+		codes = codes[:, quant_level]
+		codes = rearrange(codes, "t -> 1 t")
+
+
+	if cfg.audio_backend == "dac":
+		emb = model.quantizer.quantizers[quant_level]
+
+		x = emb.decode_code(codes) # codebook lookup: (1, codebook_dim, t)
+		x = emb.out_proj(x) # project up to the codec's latent dim: (1, latent_dim, t)
+		x = x[0].t().detach() # (t, latent_dim)
+
+		return x
+
+	raise NotImplementedError('Currently only the DAC backend is supported')
 
 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index e8341a9..ef6a808 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -18,7 +18,7 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import trange
 
-from ..emb.qnt import trim
+from ..emb.qnt import trim, encode_as_embedding
 
 from .lora import enable_lora
 
@@ -106,6 +106,12 @@ class AR_NAR(Base):
 	def monolithic(self) -> bool:
 		return True
 
+	@property
+	def use_external_audio_embeddings(self) -> bool:
+		if hasattr(self, "config") and self.config:
+			return self.config.use_external_audio_embeddings
+		return cfg.model.use_external_audio_embeddings
+
 	@property
 	def version(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -191,14 +197,23 @@ class AR_NAR(Base):
 
 			resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
-			# append stop tokens for AR
-			# could technically do it in the .inputs call
 			for i in range(batch_size):
+				# cap quant_level if it exceeds its corresponding resp/prom
+				if quant_levels[i] >= resps_list[i].shape[-1]:
+					quant_levels[i] = resps_list[i].shape[-1] - 1
+
+				if quant_levels[i] >= proms_list[i].shape[-1]:
+					quant_levels[i] = proms_list[i].shape[-1] - 1
+
 				# only apply stop token for RVQ level 0
 				if quant_levels[i] > 0:
 					continue
+
+				# append stop tokens for AR
+				# could technically do it in the .inputs call
 				resps_list[i] = torch.cat([ resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
 
+
 			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f2658b2..1eb175c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -31,6 +31,8 @@ from .arch import *
 from ..utils import wrapper as ml
 from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
+from ..emb.qnt import encode_as_embedding
+
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
@@ -169,6 +171,26 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
+# co-opts the audio backend's embeddings
+# inherits for use of the stop token
+class AudioEmbedding_External(AudioEmbedding):
+	def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
+		if not input.shape[0]: # empty input: fall back to the learned embeddings
+			return super().forward( input )
+
+		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+		has_stop_token = quant_level == 0 and input[-1] == 1024 # the stop token only exists for the AR (level 0)
+
+		if has_stop_token:
+			input = input[:-1]
+
+		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
+
+		if has_stop_token: # the codec has no stop token, so it still comes from the learned embedding
+			stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
+			embedding = torch.concat( [ embedding, stop_token ] )
+
+		return embedding
 # per-level classification
 class AudioClassifier(nn.Module):
 	def __init__(
@@ -269,6 +291,10 @@ class Base(nn.Module):
 	def monolithic(self) -> bool:
 		return False
 
+	@property
+	def use_external_audio_embeddings(self) -> bool:
+		return False
+
 	@property
 	def version(self) -> int:
 		return 1
@@ -366,6 +392,15 @@ class Base(nn.Module):
 				l_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
+		elif self.use_external_audio_embeddings:
+			self.proms_emb = AudioEmbedding_External(
+				[n_prom_tokens] * self.n_prom_levels, d_model,
+				sums=audio_embedding_sums,
+			)
+			self.resps_emb = AudioEmbedding_External(
+				l_tokens, d_model,
+				sums=audio_embedding_sums,
+			)
 		else:
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
@@ -835,6 +870,7 @@ class Base(nn.Module):
 		# technically can provide a map for input_name => embedding, but some embedding requires additional processing
 		embedding = None
 
+		# is already an embedding
 		if name == "task":
 			# noop
 			# *maybe* inject a token for specifying task type
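
Note on the DAC path: encode_as_embedding() is just a per-level codebook lookup
followed by that level's output projection. A minimal standalone sketch of the
same lookup, assuming the descript-audio-codec package ("dac"); the 44khz
variant and the random codes are illustrative, not taken from this repo:

	import torch
	import dac

	# load a pretrained DAC model (variant illustrative)
	model = dac.DAC.load( dac.utils.download(model_type="44khz") ).eval()

	quant_level = 0
	# stand-in RVQ codes for one utterance: (1, t) token ids at a single level
	codes = torch.randint(0, 1024, (1, 50))

	with torch.inference_mode():
		vq = model.quantizer.quantizers[quant_level] # per-level VectorQuantize
		x = vq.decode_code(codes) # codebook lookup: (1, codebook_dim, t)
		x = vq.out_proj(x)        # project to the codec's latent dim: (1, latent_dim, t)
		embedding = x[0].t()      # (t, latent_dim): one row per audio frame

	print( embedding.shape )

For reference, DAC's quantizer.from_codes() reconstructs the codec latent by
summing exactly these per-level out_proj outputs across all levels, so each
level's "embedding" here is one summand of what the codec's decoder would see.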
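
Enabling the experiment rides on the new Model field; a hypothetical YAML
snippet (the surrounding structure follows the existing Model dataclass, with
other fields elided):

	models:
	  - name: "ar+nar"
	    capabilities: ["ar", "nar"]
	    use_external_audio_embeddings: True

One caveat worth flagging: encode_as_embedding() returns vectors at the codec's
latent width, and AudioEmbedding_External feeds them to the transformer without
any projection, so this path presumably only lines up when the model's d_model
matches that width.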