experimental method of using DAC's quantizer "embeddings" to see if it helps with model quality
This commit is contained in:
parent a8718d35a4
commit ec5eaebcbc
@@ -220,6 +220,7 @@ class Model:
     capabilities: list = field(default_factory=lambda: ["ar", "nar"])
     experimental: str | None = None # for now it sets things to be HF compatible
     kv_heads: int = 0 # MHA or GQA (for supported backends)
+    use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings
    p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
     rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
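A minimal sketch of toggling the new flag, assuming the Model dataclass above is constructed directly (the field names come from the diff's context lines; everything else here is illustrative, and any required fields without defaults are omitted):

    # hypothetical usage: enable the experimental external embeddings
    model = Model(
        capabilities=["ar", "nar"],
        use_external_audio_embeddings=True,  # route prom/resp embeddings through the audio backend
    )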
@@ -182,7 +182,6 @@ def unload_model():
     _load_model.cache_clear()
     _load_encodec_model.cache_clear() # because vocos can only decode
 
-
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
     # upcast so it won't whine
@@ -264,6 +263,31 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
+# some experimental shit involving using the Encodec/DAC model's embeddings itself
+@torch.inference_mode()
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+    model = _load_model(device)
+
+    codes = codes.to(device=device, dtype=torch.int32)
+
+    if codes.dim() == 1:
+        codes = rearrange(codes, "t -> 1 t")
+    else:
+        codes = codes[:, quant_level]
+        codes = rearrange(codes, "t -> 1 t")
+
+    if cfg.audio_backend == "dac":
+        emb = model.quantizer.quantizers[quant_level]
+
+        x = emb.decode_code(codes)
+        x = emb.out_proj(x)
+        x = x[0].t().detach()
+
+        return x
+
+    raise Exception(f'Currently only DAC is supported')
+
 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
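A rough usage sketch of the new helper, assuming the DAC backend is active and codes are shaped [t, n_levels] as elsewhere in this codebase (the shape comments follow from the rearrange and the final transpose in the diff):

    # hypothetical call into the new helper
    import torch

    codes = torch.randint(0, 1024, (500, 8))        # [t, levels] codes for one utterance
    emb = encode_as_embedding(codes, quant_level=0)  # level-0 codes decoded through the DAC quantizer
    # emb is [t, d]: x[0].t() transposes the quantizer's [d, t] output into one embedding per frame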
@@ -18,7 +18,7 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import trange
 
-from ..emb.qnt import trim
+from ..emb.qnt import trim, encode_as_embedding
 
 from .lora import enable_lora
@@ -106,6 +106,12 @@ class AR_NAR(Base):
     def monolithic(self) -> bool:
         return True
 
+    @property
+    def use_external_audio_embeddings(self) -> bool:
+        if hasattr(self, "config") and self.config:
+            return self.config.use_external_audio_embeddings
+        return cfg.model.use_external_audio_embeddings
+
     @property
     def version(self) -> int:
         if hasattr(self, "config") and self.config:
@@ -191,14 +197,23 @@ class AR_NAR(Base):
 
     resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
-    # append stop tokens for AR
-    # could technically do it in the .inputs call
     for i in range(batch_size):
+        # cap quant_level if it exceeds its corresponding resp/prom
+        if quant_levels[i] >= resps_list[i].shape[-1]:
+            quant_levels[i] = resps_list[i].shape[-1] - 1
+
+        if quant_levels[i] >= proms_list[i].shape[-1]:
+            quant_levels[i] = proms_list[i].shape[-1] - 1
+
         # only apply stop token for RVQ level 0
         if quant_levels[i] > 0:
             continue
+
+        # append stop tokens for AR
+        # could technically do it in the .inputs call
         resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
 
     inputs = self.inputs(
         text_list=text_list,
         proms_list=proms_list,
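To see why the new capping matters: a prompt may be encoded with fewer RVQ levels than the sampled training level, and without the clamp, indexing that level would run past the tensor's last dimension. A toy illustration with made-up shapes:

    # hypothetical: clamping a sampled quant level to what the prompt actually has
    import torch

    proms = torch.randint(0, 1024, (300, 4))  # prompt with only 4 RVQ levels
    quant_level = 6                            # sampled level exceeds the prompt's depth

    if quant_level >= proms.shape[-1]:
        quant_level = proms.shape[-1] - 1      # capped to 3, the deepest available level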
@@ -31,6 +31,8 @@ from .arch import *
 from ..utils import wrapper as ml
 from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
+from ..emb.qnt import encode_as_embedding
+
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -169,6 +171,26 @@ class AudioEmbedding(nn.Module):
 
         return x
 
+# subjugates the audio backend's embeddings
+# inherits for use of the stop token
+class AudioEmbedding_External(AudioEmbedding):
+    def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
+        if not input.shape[0]:
+            return super().forward( input )
+
+        quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+        has_stop_token = quant_level == 0 and input[-1] == 1024
+
+        if has_stop_token:
+            input = input[:-1]
+
+        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
+
+        if has_stop_token:
+            stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
+            embedding = torch.concat( [ embedding, stop_token ] )
+
+        return embedding
+
 # per-level classification
 class AudioClassifier(nn.Module):
     def __init__(
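To make the stop-token handling concrete: the DAC codebooks cover ids 0..1023, so the stop token (1024) has no backend embedding. The class peels it off, embeds the real codes externally, then appends the learned embedding for 1024. A toy trace with assumed shapes:

    # hypothetical trace of AudioEmbedding_External.forward on a level-0 AR sequence
    import torch

    seq = torch.tensor([12, 945, 310, 1024], dtype=torch.int16)  # codes plus stop token
    has_stop_token = seq[-1] == 1024   # True: 1024 lies outside the DAC codebook
    codes = seq[:-1]                   # only real codes reach encode_as_embedding -> [3, d]
    # the learned embedding of 1024 -> [1, d]; the concatenated result -> [4, d]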
@@ -269,6 +291,10 @@ class Base(nn.Module):
     def monolithic(self) -> bool:
         return False
 
+    @property
+    def use_external_audio_embeddings(self) -> bool:
+        return False
+
     @property
     def version(self) -> int:
         return 1
@@ -366,6 +392,15 @@ class Base(nn.Module):
                 l_tokens, d_model,
                 levels=self.n_resp_levels if self.version > 3 else None,
             )
+        elif self.use_external_audio_embeddings:
+            self.proms_emb = AudioEmbedding_External(
+                [n_prom_tokens] * self.n_prom_levels, d_model,
+                sums=audio_embedding_sums,
+            )
+            self.resps_emb = AudioEmbedding_External(
+                l_tokens, d_model,
+                sums=audio_embedding_sums,
+            )
         else:
             self.proms_emb = AudioEmbedding(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
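One consequence of this wiring: encode_as_embedding is decorated with @torch.inference_mode() and detaches its output, so on this path the prompt/response embeddings carry no gradients back into the DAC quantizer; only the rest of the network trains. A tiny standalone demo of that property (the lookup function is a stand-in, not the real code path):

    # hypothetical: why the external path contributes no gradients to the quantizer
    import torch

    @torch.inference_mode()
    def lookup(codes):
        table = torch.randn(1024, 8)   # stand-in for a DAC codebook
        return table[codes].detach()

    emb = lookup(torch.tensor([1, 2, 3]))
    print(emb.requires_grad)           # False: the backend stays frozen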
@@ -835,6 +870,7 @@ class Base(nn.Module):
         # technically can provide a map for input_name => embedding, but some embedding requires additional processing
         embedding = None
 
+        # is already an embedding
         if name == "task":
             # noop
             # *maybe* inject a token for specifying task type