experimental method of using DAC's quantizer "embeddings" to see if it helps with model quality
parent a8718d35a4
commit ec5eaebcbc
@@ -220,6 +220,7 @@ class Model:
     capabilities: list = field(default_factory=lambda: ["ar", "nar"])
     experimental: str | None = None # for now it sets things to be HF compatible
     kv_heads: int = 0 # MHA or GQA (for supported backends)
+    use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings

     p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
     rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
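For reference, a minimal sketch of how the new flag might be toggled. The Model dataclass below is a stripped-down stand-in carrying only a few of the fields from the hunk above, not the project's actual definition:

    from dataclasses import dataclass, field

    @dataclass
    class Model:  # stripped-down stand-in, not the project's real class
        capabilities: list = field(default_factory=lambda: ["ar", "nar"])
        kv_heads: int = 0  # MHA or GQA (for supported backends)
        use_external_audio_embeddings: bool = False  # the new flag

    model = Model(use_external_audio_embeddings=True)
    assert model.use_external_audio_embeddings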
@@ -182,7 +182,6 @@ def unload_model():
     _load_model.cache_clear()
     _load_encodec_model.cache_clear() # because vocos can only decode
-

 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
     # upcast so it won't whine
@@ -264,6 +263,31 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

+# some experimental shit involving using the Encodec/DAC model's embeddings itself
+@torch.inference_mode()
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+    model = _load_model(device)
+
+    codes = codes.to(device=device, dtype=torch.int32)
+
+    if codes.dim() == 1:
+        codes = rearrange(codes, "t -> 1 t")
+    else:
+        codes = codes[:, quant_level]
+        codes = rearrange(codes, "t -> 1 t")
+
+    if cfg.audio_backend == "dac":
+        emb = model.quantizer.quantizers[quant_level]
+
+        x = emb.decode_code(codes)
+        x = emb.out_proj(x)
+        x = x[0].t().detach()
+
+        return x
+
+    raise Exception(f'Currently only DAC is supported')

 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
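A hedged usage sketch for the new helper: fed a (t, n_levels) code tensor, it selects one RVQ level, dequantizes it through that level's codebook (decode_code), projects it back up (out_proj), and returns a (t, d) float tensor. The shapes are inferred from the rearrange/transpose calls above, and actually running this requires cfg.audio_backend == "dac" with a loadable DAC model:

    import torch

    # hypothetical call; assumes a configured, loadable DAC backend
    codes = torch.randint(0, 1024, (50, 8), dtype=torch.int32)  # (t, n_levels)
    emb = encode_as_embedding(codes, quant_level=2, device="cpu")
    print(emb.shape)  # expected: (50, d) -- one codec-space embedding per frame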
@@ -18,7 +18,7 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import trange

-from ..emb.qnt import trim
+from ..emb.qnt import trim, encode_as_embedding

 from .lora import enable_lora
@@ -106,6 +106,12 @@ class AR_NAR(Base):
     def monolithic(self) -> bool:
         return True

+    @property
+    def use_external_audio_embeddings(self) -> bool:
+        if hasattr(self, "config") and self.config:
+            return self.config.use_external_audio_embeddings
+        return cfg.model.use_external_audio_embeddings
+
     @property
     def version(self) -> int:
         if hasattr(self, "config") and self.config:
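Both this override and the Base default further down follow the class's existing config-first lookup (the same shape as version): prefer the per-model config object when one is attached, otherwise fall back to the global default. The same pattern in isolation, with illustrative stand-in names:

    class _Cfg:  # stand-in for the global cfg.model defaults
        use_external_audio_embeddings = False

    class ARNARStub:  # stand-in for AR_NAR, not the real class
        def __init__(self, config=None):
            self.config = config

        @property
        def use_external_audio_embeddings(self) -> bool:
            if getattr(self, "config", None):  # per-model config wins
                return self.config.use_external_audio_embeddings
            return _Cfg.use_external_audio_embeddings  # global fallback

    assert ARNARStub().use_external_audio_embeddings is False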
@@ -191,14 +197,23 @@

         resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

-        # append stop tokens for AR
-        # could technically do it in the .inputs call
         for i in range(batch_size):
+            # cap quant_level if it exceeds its corresponding resp/prom
+            if quant_levels[i] >= resps_list[i].shape[-1]:
+                quant_levels[i] = resps_list[i].shape[-1] - 1
+
+            if quant_levels[i] >= proms_list[i].shape[-1]:
+                quant_levels[i] = proms_list[i].shape[-1] - 1
+
             # only apply stop token for RVQ level 0
             if quant_levels[i] > 0:
                 continue
+
+            # append stop tokens for AR
+            # could technically do it in the .inputs call
             resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])

         inputs = self.inputs(
             text_list=text_list,
             proms_list=proms_list,
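The new block above clamps a sampled RVQ level so it never indexes past the deepest level a given resp/prom tensor actually carries. The same clamp in isolation, on toy shapes:

    import torch

    resps = torch.randint(0, 1024, (10, 2))  # a response with only 2 RVQ levels
    quant_level = 7                          # sampled level is deeper than that

    if quant_level >= resps.shape[-1]:
        quant_level = resps.shape[-1] - 1

    assert quant_level == 1  # clamped to the deepest available level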
@@ -31,6 +31,8 @@ from .arch import *
 from ..utils import wrapper as ml
 from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample

+from ..emb.qnt import encode_as_embedding
+
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -169,6 +171,26 @@ class AudioEmbedding(nn.Module):

         return x

+# subjugates the audio backend's embeddings
+# inherits for use of the stop token
+class AudioEmbedding_External(AudioEmbedding):
+    def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
+        if not input.shape[0]:
+            return super().forward( input )
+
+        quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+        has_stop_token = quant_level == 0 and input[-1] == 1024
+
+        if has_stop_token:
+            input = input[:-1]
+
+        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
+
+        if has_stop_token:
+            stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
+            embedding = torch.concat( [ embedding, stop_token ] )
+
+        return embedding
+
 # per-level classification
 class AudioClassifier(nn.Module):
     def __init__(
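One subtlety in the forward pass above: the stop token (1024) has no entry in the codec's codebooks, so it is sliced off before encode_as_embedding and re-attached afterwards using the inherited learned table. A runnable toy version of that split-and-rejoin, with dac_embed as a stand-in for the codec call:

    import torch

    def dac_embed(codes: torch.Tensor) -> torch.Tensor:
        # stand-in for encode_as_embedding(): (t,) codes -> (t, d) embeddings
        return torch.randn(codes.shape[0], 1024)

    learned = torch.nn.Embedding(1025, 1024)   # stand-in for the inherited table
    codes = torch.tensor([12, 900, 45, 1024])  # AR sequence ending in the stop token

    has_stop = codes[-1] == 1024
    body = codes[:-1] if has_stop else codes

    embedding = dac_embed(body)
    if has_stop:
        stop = learned(torch.tensor([1024]))   # only the stop token uses the table
        embedding = torch.concat([embedding, stop])

    print(embedding.shape)  # torch.Size([4, 1024])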
@@ -269,6 +291,10 @@ class Base(nn.Module):
     def monolithic(self) -> bool:
         return False

+    @property
+    def use_external_audio_embeddings(self) -> bool:
+        return False
+
     @property
     def version(self) -> int:
         return 1
@@ -366,6 +392,15 @@ class Base(nn.Module):
                 l_tokens, d_model,
                 levels=self.n_resp_levels if self.version > 3 else None,
             )
+        elif self.use_external_audio_embeddings:
+            self.proms_emb = AudioEmbedding_External(
+                [n_prom_tokens] * self.n_prom_levels, d_model,
+                sums=audio_embedding_sums,
+            )
+            self.resps_emb = AudioEmbedding_External(
+                l_tokens, d_model,
+                sums=audio_embedding_sums,
+            )
         else:
             self.proms_emb = AudioEmbedding(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
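So the constructor now picks the embedding class from config; the external path deliberately keeps the same constructor shape (tokens, d_model, sums=...) so it swaps in without touching call sites. A runnable paraphrase of just that branch, with stub classes standing in for the real modules:

    class AudioEmbedding:                           # stub for the learned-table module
        def __init__(self, tokens, d_model, sums=True):
            self.kind = "learned"

    class AudioEmbedding_External(AudioEmbedding):  # stub for the codec-backed one
        def __init__(self, tokens, d_model, sums=True):
            self.kind = "external"

    use_external_audio_embeddings = True
    emb_cls = AudioEmbedding_External if use_external_audio_embeddings else AudioEmbedding
    proms_emb = emb_cls([1024] * 8, d_model=1024)
    print(proms_emb.kind)  # "external"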
@@ -835,6 +870,7 @@ class Base(nn.Module):
         # technically can provide a map for input_name => embedding, but some embedding requires additional processing
         embedding = None

+        # is already an embedding
         if name == "task":
             # noop
             # *maybe* inject a token for specifying task type