diff --git a/README.md b/README.md
index 0ddd612..277a245 100755
--- a/README.md
+++ b/README.md
@@ -143,7 +143,8 @@ For audio backends:
 * [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
   - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
 * [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
-  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge. Audio encoded through the 44KHz seems to work.
+  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
+  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stop improving after a while.
 
 `llama`-based models also support different attention backends:
 * `math`: torch's SDPA's `math` implementation
diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py
index c57b933..d2ae337 100644
--- a/scripts/process_dataset.py
+++ b/scripts/process_dataset.py
@@ -8,8 +8,8 @@ from pathlib import Path
 from vall_e.config import cfg
 
 # things that could be args
-cfg.sample_rate = 44_000
-cfg.inference.audio_backend = "dac"
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
 """
 cfg.inference.weight_dtype = "bfloat16"
 cfg.inference.dtype = torch.bfloat16
diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py
index b4ce768..69b1711 100755
--- a/scripts/process_libritts.py
+++ b/scripts/process_libritts.py
@@ -1,14 +1,29 @@
 import os
 import json
 import torch
+import numpy as np
 
 from tqdm.auto import tqdm
 from pathlib import Path
+
+from vall_e.config import cfg
+
+# things that could be args
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
+"""
+cfg.inference.weight_dtype = "bfloat16"
+cfg.inference.dtype = torch.bfloat16
+cfg.inference.amp = True
+"""
+
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
 
+audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+
 input_dataset = "LibriTTS_R"
-output_dataset = "LibriTTS-Train"
+output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
 device = "cuda"
 
 txts = []
@@ -32,24 +47,61 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
                 inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
                 outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')
 
-                if ".original.txt" in filename and not _replace_file_extension(outpath, ".json").exists():
-                    txts.append([inpath, outpath])
-                if ".wav" in filename and not _replace_file_extension(outpath, ".dac").exists():
-                    wavs.append([inpath, outpath])
+                if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists():
+                    txts.append((
+                        inpath,
+                        outpath
+                    ))
 
-for paths in tqdm(txts, desc="Phonemizing..."):
-    text = open(paths[0], "r", encoding="utf-8").read()
-    phones = valle_phonemize(text)
-    data = {
-        "text": text,
-        "phonemes": phones,
-        "language": "english",
-    }
-    open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
-    #phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
-    #open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
+for paths in tqdm(txts, desc="Processing..."):
+    inpath, outpath = paths
+    try:
+        if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
+            data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
+            qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
+
+            if not isinstance(data["phonemes"], str):
+                data["phonemes"] = "".join(data["phonemes"])
 
-for paths in tqdm(wavs, desc="Quantizing..."):
-    qnt = valle_quantize(paths[0], device=device)
-    qnt.save(_replace_file_extension(paths[1], ".dac"))
-    #torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
+            for k, v in data.items():
+                qnt[()]['metadata'][k] = v
+
+            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
+        else:
+            text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+
+            phones = valle_phonemize(text)
+            qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
+
+            if cfg.inference.audio_backend == "dac":
+                np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                    "codes": qnt.codes.cpu().numpy().astype(np.uint16),
+                    "metadata": {
+                        "original_length": qnt.original_length,
+                        "sample_rate": qnt.sample_rate,
+
+                        "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+                        "chunk_length": qnt.chunk_length,
+                        "channels": qnt.channels,
+                        "padding": qnt.padding,
+                        "dac_version": "1.0.0",
+
+                        "text": text.strip(),
+                        "phonemes": "".join(phones),
+                        "language": "en",
+                    },
+                })
+            else:
+                np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                    "codes": qnt.cpu().numpy().astype(np.uint16),
+                    "metadata": {
+                        "original_length": qnt.shape[0] / 75.0,
+                        "sample_rate": cfg.sample_rate,
+
+                        "text": text.strip(),
+                        "phonemes": "".join(phones),
+                        "language": "en",
+                    },
+                })
+    except Exception as e:
+        tqdm.write(f"Failed to process: {paths}: {e}")
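For reference, the `np.save(open(...), {...})` calls above store a plain dict inside a zero-dimensional object array, so reading one of the resulting `.enc`/`.dac` files back looks roughly like the sketch below (the path is illustrative; `allow_pickle=True` and the `[()]` unwrap mirror how the rest of this diff loads them):

```python
import numpy as np

# illustrative path; actual files land under the output_dataset folder built above
path = "LibriTTS-Train-24KHz/speaker/utterance.enc"

qnt = np.load(path, allow_pickle=True)[()]   # unwrap the 0-d object array into the dict
codes = qnt["codes"]                         # uint16 array of quantized audio codes
meta  = qnt["metadata"]                      # text, phonemes, language, sample_rate, ...

print(codes.shape, meta["language"])
```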
diff --git a/vall_e/config.py b/vall_e/config.py
index 7cd39d8..7f03e3f 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -156,7 +156,7 @@ class Dataset:
     p_resp_append: float = 1.0
 
     sample_type: str = "path" # path | speaker
-    
+
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
     _frames_per_second: int = 0 # allows setting your own hint
@@ -166,7 +166,7 @@ class Dataset:
         if self._frames_per_second > 0:
             return self._frames_per_second
 
-        if cfg.inference.audio_backend == "dac":
+        if cfg.audio_backend == "dac":
             # using the 44KHz model with 24KHz sources has a frame rate of 41Hz
             if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
                 return 41
@@ -567,7 +567,7 @@ class Inference:
     amp: bool = False
 
     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-    audio_backend: str = "vocos" # encodec, vocos, dac
+    audio_backend: str = "" # encodec, vocos, dac
 
     # legacy / backwards compat
     use_vocos: bool = True
@@ -628,6 +628,8 @@ class Config(_Config):
     sample_rate: int = 24_000
     variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
 
+    audio_backend: str = "vocos"
+
     @property
     def distributed(self):
         return world_size() > 1
@@ -726,6 +728,9 @@ class Config(_Config):
 
         if self.trainer.backend == "local" and self.distributed:
             self.trainer.ddp = True
+
+        if self.inference.audio_backend != "" and self.audio_backend == "":
+            self.audio_backend = self.inference.audio_backend # Preserves the old behavior
 
 class NaiveTokenizer:
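The backwards-compat shim added at the end of `config.py` reduces to the small standalone sketch below (`resolve_audio_backend` is an illustrative name, not part of the codebase): the legacy `cfg.inference.audio_backend` is only promoted to the new top-level `cfg.audio_backend` when the latter is left empty.

```python
def resolve_audio_backend(inference_backend: str, top_level_backend: str) -> str:
    # mirror of the shim above: the legacy inference-level setting wins only
    # when the new top-level setting was left unset
    if inference_backend != "" and top_level_backend == "":
        return inference_backend
    return top_level_backend

assert resolve_audio_backend("dac", "") == "dac"      # legacy-style config still selects DAC
assert resolve_audio_backend("", "vocos") == "vocos"  # new-style config takes precedence
```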
diff --git a/vall_e/data.py b/vall_e/data.py
index f94dbf8..47b7332 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
-    return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+    return ".dac" if cfg.audio_backend == "dac" else ".enc"
 
 def _get_phone_extension():
-    return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
+    return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
 
 def _get_quant_path(path):
     return _replace_file_extension(path, _get_quant_extension())
@@ -876,12 +876,36 @@ def create_dataset_hdf5( skip_existing=True ):
         if not os.path.isdir(f'{root}/{name}/'):
             return
-        # tqdm.write(f'{root}/{name}')
+
         files = os.listdir(f'{root}/{name}/')
 
         # grab IDs for every file
         ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
 
+        """
+        # rephonemizes if you fuck up and use an old tokenizer...
+        for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
+            key = f'{type}/{speaker_name}/{id}'
+
+            if key not in hf:
+                continue
+
+            group = hf[key]
+
+            if "phonemes" not in entry:
+                continue
+            if "text" not in group:
+                continue
+
+            txt = entry["phonemes"]
+            phn = "".join(txt)
+            phn = cfg.tokenizer.encode(phn)
+            phn = np.array(phn).astype(np.uint8)
+
+            del group["text"]
+            group.create_dataset('text', data=phn, compression='lzf')
+        """
+
         for id in tqdm(ids, desc=f"Processing {name}"):
             try:
                 quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
@@ -938,8 +962,10 @@ def create_dataset_hdf5( skip_existing=True ):
             except Exception as e:
                 tqdm.write(f'Error while processing {id}: {e}')
 
+        """
         with open(str(metadata_path), "w", encoding="utf-8") as f:
             f.write( json.dumps( metadata ) )
+        """
 
 # training
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 12d9d89..bc58802 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -170,7 +170,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
     return model
 
 @cache
-def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
+def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
     if backend == "dac":
         return _load_dac_model(device, levels=levels)
     if backend == "vocos":
@@ -267,7 +267,7 @@ def _replace_file_extension(path, suffix):
 
 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
-    if cfg.inference.audio_backend == "dac":
+    if cfg.audio_backend == "dac":
         model = _load_dac_model(device, levels=levels )
         signal = AudioSignal(wav, sample_rate=sr)
 
@@ -307,7 +307,7 @@ def encode_from_files(paths, device="cuda"):
 
     wav = torch.cat(wavs, dim=-1)
 
-    return encode(wav, sr, "cpu")
+    return encode(wav, sr, device)
 
 def encode_from_file(path, device="cuda"):
     if isinstance( path, list ):
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 7a05729..cd3cde0 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -112,7 +112,18 @@ class TTS():
             paths = [ Path(p) for p in paths.split(";") ]
 
         # merge inputs
-        res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
+
+        proms = []
+
+        for path in paths:
+            prom = qnt.encode_from_file(path)
+            if hasattr( prom, "codes" ):
+                prom = prom.codes
+            prom = prom[0][:, :].t().to(torch.int16)
+
+            proms.append( prom )
+
+        res = torch.cat(proms)
 
         if trim_length:
             res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )
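The `hasattr( prom, "codes" )` check above exists because the two encoder paths return different things: the DAC backend yields an artifact object whose quantized codes sit under `.codes`, while the EnCodec/Vocos path yields the codes tensor directly (the same split handled in `process_libritts.py` earlier in this diff). A minimal sketch of that normalization, with an illustrative helper name:

```python
import torch

def to_prom_codes(prom) -> torch.Tensor:
    # DAC returns an object carrying .codes; EnCodec/Vocos return the tensor directly
    if hasattr(prom, "codes"):
        prom = prom.codes
    # drop the batch dim and transpose to (frames, levels), as in the merge loop above
    return prom[0][:, :].t().to(torch.int16)

fake_codes = torch.zeros(1, 8, 150, dtype=torch.int32)  # (batch, levels, frames)
print(to_prom_codes(fake_codes).shape)                   # torch.Size([150, 8])
```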
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d0c6ca0..1da1843 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -319,7 +319,7 @@ class AR_NAR(Base):
 def example_usage():
     #cfg.trainer.backend = "local"
     cfg.hyperparameters.gradient_accumulation_steps = 1
-    if cfg.inference.audio_backend == "dac":
+    if cfg.audio_backend == "dac":
         cfg.sample_rate = 44_000
 
     from functools import partial
@@ -340,7 +340,7 @@ def example_usage():
         return torch.tensor( cfg.tokenizer.encode(content) )
 
     def _load_quants(path) -> Tensor:
-        if cfg.inference.audio_backend == "dac":
+        if cfg.audio_backend == "dac":
             qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
             return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
         return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
@@ -456,13 +456,13 @@ def example_usage():
     @torch.inference_mode()
     def sample( name, steps=1000 ):
-        if cfg.inference.audio_backend == "dac" and name == "init":
+        if cfg.audio_backend == "dac" and name == "init":
             return
 
         engine.eval()
         resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
 
-        if cfg.inference.audio_backend != "dac":
+        if cfg.audio_backend != "dac":
             for i, o in enumerate(resps_list):
                 _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
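As a rough sanity check of the frame bookkeeping these changes rely on: EnCodec at 24KHz produces 75 code frames per second (hence the `/ 75.0` when `original_length` is recorded in `process_libritts.py`), and `inference.py` trims prompts with `int(cfg.dataset.frames_per_second * trim_length)`. The numbers below only illustrate that arithmetic.

```python
frames_per_second = 75          # EnCodec / Vocos at 24KHz
trim_length = 3.0               # seconds of prompt to keep

print(int(frames_per_second * trim_length))  # 225 frames kept from the prompt
print(1500 / frames_per_second)              # a 1500-frame utterance is 20.0 seconds
```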