fixed dac

mrq 2025-03-12 23:17:27 -05:00
parent ba5f3d19b4
commit 6ee505cffd
5 changed files with 17 additions and 4 deletions

data/qnt.dac: new binary file (contents not shown)
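data/qnt.dac is presumably the repo's reference quantized clip re-encoded for the DAC backend; the artifact itself is binary. For orientation, a file like it can be produced with the descript-audio-codec API that the later hunks import (AudioSignal, load_model). This is only a sketch with a hypothetical source path, not the script actually used to generate the file:

# Sketch: producing a .dac artifact such as data/qnt.dac.
# The "44khz" preset and the input path are assumptions; compress()/save()
# come from descript-audio-codec, AudioSignal from audiotools.
from audiotools import AudioSignal
from dac.utils import load_model as load_dac_model

model = load_dac_model(model_type="44khz").eval()
signal = AudioSignal("data/qnt.wav")   # hypothetical source clip
artifact = model.compress(signal)      # DACFile holding the RVQ codes
artifact.save("data/qnt.dac")          # serialize codes + metadata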


@@ -318,7 +318,6 @@ class Model:
     name: str = "ar+nar" # vanity name for the model
     version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
     size: str | dict = "full" # preset string or explicitly defined dimensionality
-    resp_levels: int = 8 # RVQ-bin levels this model supports
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
     langs: int = 1 # defined languages (semi-unused)
     tones: int = 1 # defined tones (unused)
@@ -382,6 +381,16 @@ class Model:
     def tokens(self):
         return self.audio_tokens
+    @property
+    def resp_levels(self):
+        if isinstance(self.size, dict) and "resp_levels" in self.size:
+            return self.size['resp_levels']
+        if cfg.audio_backend == "dac":
+            return 9
+        return 8
     @property
     def audio_tokens(self):
         if isinstance(self.size, dict) and "audio_tokens" in self.size:
@@ -1043,6 +1052,10 @@ class Config(BaseConfig):
             if not isinstance( model, dict ):
                 continue
+            # was made an inherent property tied to audio_backend
+            if "resp_levels" in model:
+                del model["resp_levels"]
             # to-do: prune unused keys in here too automatically
             if "experimental" not in model or not model["experimental"]:
                 model["experimental"] = {}


@@ -2,7 +2,7 @@ import torch
 from dac import DACFile
 from audiotools import AudioSignal
-from dac.utils import load_model as __load_dac_model
+from dac.utils import load_model as load_dac_model
 from typing import Union
 from pathlib import Path


@@ -117,7 +117,7 @@ def _load_dac_model(device="cuda", dtype=None):
     else:
         raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
-    model = __load_dac_model(**kwargs)
+    model = load_dac_model(**kwargs)
     model = model.to(device)
     model = model.eval()
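The renamed load_dac_model is descript-audio-codec's load_model helper; _load_dac_model wraps it, choosing a pretrained variant from cfg.sample_rate and raising on anything unsupported. The exact kwargs built above are not visible in this hunk, so the model_type mapping below is an assumption based on the published DAC presets (whose default 44 kHz model uses 9 codebooks, matching the resp_levels default of 9 for the dac backend):

# Sketch of the sample-rate to pretrained-DAC selection; the model_type
# values are the published descript-audio-codec presets, not code from
# this commit.
import torch
from dac.utils import load_model as load_dac_model

def load_dac_for_sample_rate(sample_rate: int, device: str = "cuda") -> torch.nn.Module:
    if sample_rate == 44_100:
        kwargs = dict(model_type="44khz")
    elif sample_rate == 24_000:
        kwargs = dict(model_type="24khz")
    elif sample_rate == 16_000:
        kwargs = dict(model_type="16khz")
    else:
        raise Exception(f'unsupported sample rate: {sample_rate}')

    model = load_dac_model(**kwargs)
    return model.to(device).eval()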


@@ -828,7 +828,7 @@ def example_usage():
     available_tasks = ["tts-nar"]
     model = AR_NAR_V2(**kwargs).to(cfg.device)
-    steps = 500 // batch_size
+    steps = 250 # // batch_size
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""