diff --git a/vall_e/config.py b/vall_e/config.py
index 68ca9c4..4605609 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -484,7 +484,12 @@ class Inference:
 	amp: bool = False
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
+	audio_backend: str = "vocos"
+
+	# legacy / backwards compat
 	use_vocos: bool = True
+	use_encodec: bool = True
+	use_dac: bool = True
 
 	recurrent_chunk_size: int = 0
 	recurrent_forward: bool = False
@@ -576,22 +581,30 @@ class Config(_Config):
 			self.dataset.use_hdf5 = False
 
 	def format( self ):
+		#if not isinstance(self.dataset, type):
 		self.dataset = Dataset(**self.dataset)
+		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
+		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
+		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
+
+		#if not isinstance(self.model, type):
 		if self.models is not None:
 			self.model = Model(**next(iter(self.models)))
 		else:
 			self.model = Model(**self.model)
 
-		self.hyperparameters = Hyperparameters(**self.hyperparameters)
-		self.evaluation = Evaluation(**self.evaluation)
-		self.trainer = Trainer(**self.trainer)
-		self.inference = Inference(**self.inference)
-		self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
-		self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
-
-		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
-		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
-		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
+		#if not isinstance(self.hyperparameters, type):
+		self.hyperparameters = Hyperparameters(**self.hyperparameters)
+		#if not isinstance(self.evaluation, type):
+		self.evaluation = Evaluation(**self.evaluation)
+		#if not isinstance(self.trainer, type):
+		self.trainer = Trainer(**self.trainer)
+		if not isinstance(self.trainer.deepspeed, type):
+			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
+		#if not isinstance(self.inference, type):
+		self.inference = Inference(**self.inference)
+		#if not isinstance(self.bitsandbytes, type):
+		self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
 
 cfg = Config.from_cli()
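For reference, a rough sketch (not part of the patch) of how the new `audio_backend` knob is meant to be consumed: the legacy `use_*` booleans above are kept only so older configs keep parsing, while the dispatch itself lives in the reworked `_load_model()` in qnt.py below. The device string and prints are illustrative, and the descript-audio-codec (dac) and audiotools packages are assumed to be installed.

from vall_e.config import cfg
from vall_e.emb import qnt

# "dac" and "vocos" are dispatched explicitly by the patched _load_model();
# any other value falls back to the raw EnCodec decoder
print(cfg.inference.audio_backend)

# the backend can also be forced per call, since the default argument
# captures cfg.inference.audio_backend at import time
model = qnt._load_model(device="cuda", backend="dac")
print(model.backend, model.sample_rate)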
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 242d852..5c28673 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -9,20 +9,89 @@
 from functools import cache
 from pathlib import Path
 
-from encodec import EncodecModel
-from encodec.utils import convert_audio
 from einops import rearrange
 from torch import Tensor
 from tqdm import tqdm
 
+try:
+	from encodec import EncodecModel
+	from encodec.utils import convert_audio
+except Exception as e:
+	cfg.inference.use_encodec = False
+
 try:
 	from vocos import Vocos
 except Exception as e:
 	cfg.inference.use_vocos = False
 
+try:
+	from dac import DACFile
+	from audiotools import AudioSignal
+	from dac.utils import load_model as __load_dac_model
+
+	"""
+	Patch decode to skip things related to the metadata (namely the waveform trimming)
+	So far it seems the raw waveform can just be returned without any post-processing
+	A smart implementation would just reuse the values from the input prompt
+	"""
+	from dac.model.base import CodecMixin
+
+	@torch.no_grad()
+	def CodecMixin_decompress(
+		self,
+		obj: Union[str, Path, DACFile],
+		verbose: bool = False,
+	) -> AudioSignal:
+		self.eval()
+		if isinstance(obj, (str, Path)):
+			obj = DACFile.load(obj)
+
+		original_padding = self.padding
+		self.padding = obj.padding
+
+		range_fn = range if not verbose else tqdm.trange
+		codes = obj.codes
+		original_device = codes.device
+		chunk_length = obj.chunk_length
+		recons = []
+
+		for i in range_fn(0, codes.shape[-1], chunk_length):
+			c = codes[..., i : i + chunk_length].to(self.device)
+			z = self.quantizer.from_codes(c)[0]
+			r = self.decode(z)
+			recons.append(r.to(original_device))
+
+		recons = torch.cat(recons, dim=-1)
+		recons = AudioSignal(recons, self.sample_rate)
+
+		# to-do, original implementation
+		"""
+		resample_fn = recons.resample
+		loudness_fn = recons.loudness
+
+		# If audio is > 10 minutes long, use the ffmpeg versions
+		if recons.signal_duration >= 10 * 60 * 60:
+			resample_fn = recons.ffmpeg_resample
+			loudness_fn = recons.ffmpeg_loudness
+
+		recons.normalize(obj.input_db)
+		resample_fn(obj.sample_rate)
+		recons = recons[..., : obj.original_length]
+		loudness_fn()
+		recons.audio_data = recons.audio_data.reshape(
+			-1, obj.channels, obj.original_length
+		)
+		"""
+		self.padding = original_padding
+		return recons
+
+	CodecMixin.decompress = CodecMixin_decompress
+
+except Exception as e:
+	cfg.inference.use_dac = False
+
 @cache
 def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
-	# Instantiate a pretrained EnCodec model
 	assert cfg.sample_rate == 24_000
 
 	# too lazy to un-if ladder this shit
@@ -34,8 +103,14 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
 	elif levels == 8:
 		bandwidth_id = 6.0
 
-	model = EncodecModel.encodec_model_24khz().to(device)
+	# Instantiate a pretrained EnCodec model
+	model = EncodecModel.encodec_model_24khz()
 	model.set_target_bandwidth(bandwidth_id)
+
+	model = model.to(device)
+	model = model.eval()
+
+	# extra metadata
 	model.bandwidth_id = bandwidth_id
 	model.sample_rate = cfg.sample_rate
 	model.normalize = cfg.inference.normalize
@@ -49,6 +124,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
 
 	model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
 	model = model.to(device)
+	model = model.eval()
 
 	# too lazy to un-if ladder this shit
 	bandwidth_id = 2
@@ -59,6 +135,7 @@
 	elif levels == 8:
 		bandwidth_id = 2
 
+	# extra metadata
 	model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
 	model.sample_rate = cfg.sample_rate
 	model.backend = "vocos"
@@ -66,25 +143,48 @@
 	return model
 
 @cache
-def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
-	if vocos:
-		model = _load_vocos_model(device, levels=levels)
+def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
+	kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
+
+	# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
+	if cfg.sample_rate == 44_000:
+		kwargs["model_type"] = "44khz"
+	elif cfg.sample_rate == 24_000:
+		kwargs["model_type"] = "24khz"
+	elif cfg.sample_rate == 16_000:
+		kwargs["model_type"] = "16khz"
 	else:
-		model = _load_encodec_model(device, levels=levels)
+		raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
+
+	model = __load_dac_model(**kwargs)
+	model = model.to(device)
+	model = model.eval()
+
+	# extra metadata
+	model.sample_rate = cfg.sample_rate
+	model.backend = "dac"
 
 	return model
 
+@cache
+def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
+	if backend == "dac":
+		return _load_dac_model(device, levels=levels)
+	if backend == "vocos":
+		return _load_vocos_model(device, levels=levels)
+
+	return _load_encodec_model(device, levels=levels)
+
 def unload_model():
 	_load_model.cache_clear()
-	_load_encodec_model.cache_clear()
+	_load_encodec_model.cache_clear() # because vocos can only decode
 
 @torch.inference_mode()
-def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
-	"""
-	Args:
-		codes: (b q t)
-	"""
+def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
+	# upcast so it won't whine
+	if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
+		codes = codes.to(torch.int32)
 
 	# expand if we're given a raw 1-RVQ stream
 	if codes.dim() == 1:
@@ -96,21 +196,49 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
 		codes = rearrange(codes, "t q -> 1 q t")
 
 	assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
+
+	# load the model
 	model = _load_model(device, levels=levels)
 
-	# upcast so it won't whine
-	if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
-		codes = codes.to(torch.int32)
+	# DAC uses a different pathway
+	if model.backend == "dac":
+		if metadata is None:
+			metadata = dict(
+				chunk_length=416,
+				original_length=0,
+				input_db=-12,
+				channels=1,
+				sample_rate=model.sample_rate,
+				padding=False,
+				dac_version='1.0.0',
+			)
+		# generate object with copied metadata
+		artifact = DACFile(
+			codes = codes,
+			# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
+			chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
+			original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
+			input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
+			channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
+			sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
+			padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
+			dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
+		)
+
+		return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
+
 	kwargs = {}
 	if model.backend == "vocos":
 		x = model.codes_to_features(codes[0])
 		kwargs['bandwidth_id'] = model.bandwidth_id
 	else:
+		# encodec will decode as a batch
 		x = [(codes.to(device), None)]
 
 	wav = model.decode(x, **kwargs)
 
+	# encodec will decode as a batch
 	if model.backend == "encodec":
 		wav = wav[0]
@@ -131,13 +259,14 @@ def _replace_file_extension(path, suffix):
 
 @torch.inference_mode()
-def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
-	"""
-	Args:
-		wav: (t)
-		sr: int
-	"""
+def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=False):
+	if cfg.inference.audio_backend == "dac":
+		model = _load_dac_model(device, levels=levels)
+		signal = AudioSignal(wav, sample_rate=model.sample_rate)
+		artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
+		return artifact.codes if not return_metadata else artifact
+
 	# vocos does not encode wavs to encodecs, so just use normal encodec
 	model = _load_encodec_model(device, levels=levels)
 	wav = wav.unsqueeze(0)
 	wav = convert_audio(wav, sr, model.sample_rate, model.channels)
@@ -180,8 +309,9 @@ def encode_from_file(path, device="cuda"):
 
 	return qnt
 
-# Helper Functions
-
+"""
+Helper Functions
+"""
 # trims from the start, up to `target`
 def trim( qnt, target ):
 	length = max( qnt.shape[0], qnt.shape[1] )
@@ -208,7 +338,7 @@ def trim_random( qnt, target ):
 	end = start + target
 	if end >= length:
 		start = length - target
-		end = length
+		end = length
 
 	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
@@ -233,13 +363,14 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
 			decoded[i] = decoded[i] * scale[i]
 
 	combined = sum(decoded) / len(decoded)
-	return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
+	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
 
 def main():
 	parser = argparse.ArgumentParser()
 	parser.add_argument("folder", type=Path)
 	parser.add_argument("--suffix", default=".wav")
 	parser.add_argument("--device", default="cuda")
+	parser.add_argument("--backend", default="encodec")
 	args = parser.parse_args()
 
 	device = args.device
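A rough round-trip sketch, not part of the patch, of how the reworked encode()/decode() pair might be driven with the DAC backend; the file names are placeholders and the dac/audiotools packages are assumed to be installed.

import torchaudio

# the backend has to be chosen before vall_e.emb.qnt is imported, because
# _load_model()'s default argument captures cfg.inference.audio_backend at import time
from vall_e.config import cfg
cfg.inference.audio_backend = "dac"

from vall_e.emb.qnt import encode, decode

wav, sr = torchaudio.load("prompt.wav")  # hypothetical input clip

# DAC path: encode() returns DACFile.codes, or the whole DACFile artifact
# when return_metadata=True, which decode() can then reuse as its `metadata`
artifact = encode(wav, sr, device="cuda", return_metadata=True)

# decode() routes through the patched CodecMixin.decompress() and returns
# (audio_data, sample_rate)
audio, out_sr = decode(artifact.codes, device="cuda", metadata=artifact)
torchaudio.save("reconstructed.wav", audio.cpu(), out_sr)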
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 4786fcf..c868dca 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -336,7 +336,9 @@ def example_usage():
 		phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
 		return torch.tensor([*map(symmap.get, phones)])
 
-	qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
+	qnt = torch.load(f'data/qnt{".dac" if cfg.inference.audio_backend == "dac" else ""}.pt')[0].t()[:, :cfg.model.prom_levels].to(device)
+
+	print(qnt.shape)
 
 	cfg.hyperparameters.gradient_accumulation_steps = 1
@@ -426,11 +428,15 @@ def example_usage():
 	@torch.inference_mode()
 	def sample( name, steps=600 ):
+		if cfg.inference.audio_backend == "dac" and name == "init":
+			return
+
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
+		if cfg.inference.audio_backend != "dac":
+			for i, o in enumerate(resps_list):
+				_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
 
 		resps_list = [r.unsqueeze(-1) for r in resps_list]
 		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
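example_usage() above expects a pre-quantized prompt at data/qnt.dac.pt when the DAC backend is active; here is a hedged sketch of how such a file could be produced with the patched encode(). The source clip path is a placeholder, and the (1, n_quantizers, frames) layout is inferred from the [0].t() indexing in the loader.

import torch
import torchaudio

from vall_e.config import cfg
cfg.inference.audio_backend = "dac"  # set before importing qnt, as in the sketch above

from vall_e.emb.qnt import encode

wav, sr = torchaudio.load("data/prompt.wav")  # placeholder source clip

# DAC encode() returns codes shaped (1, n_quantizers, frames), so
# torch.load(...)[0].t() in example_usage() yields a (frames, n_quantizers) prompt
codes = encode(wav, sr, device="cuda")
torch.save(codes.cpu(), "data/qnt.dac.pt")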