From 2437a86efaef93c2d01afd1f4b535e3cd808a242 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 12 May 2024 13:02:15 -0500
Subject: [PATCH] ugh

---
 README.md                  |  21 ++++-
 scripts/process_dataset.py |   2 +-
 vall_e/config.py           |   3 +-
 vall_e/data.py             | 166 +++++++++++++++++--------------------
 vall_e/emb/qnt.py          |   2 +-
 5 files changed, 96 insertions(+), 98 deletions(-)

diff --git a/README.md b/README.md
index 5f6e83e..987ced4 100755
--- a/README.md
+++ b/README.md
@@ -125,7 +125,7 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
 
 #### Backend Architectures
 
-As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
 
 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
 * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
@@ -135,9 +135,24 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
 * `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
   - Its implementation for MoE can also be utilized.
 * `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
-  - inferencing cost is about 0.5x, and MoE is not implemented.
+  - carries an inference performance penalty, and MoE is not implemented.
 
-It's recommended to use `llama` with `xformers`-based attention, as the savings are huge in comparison to even `retnet`-backed models.
+Supported audio backends:
+
+* [`encodec`](https://github.com/facebookresearch/encodec): the tried-and-tested EnCodec to encode/decode audio.
+* [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
+  - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`.
+* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality.
+  - **Note**: models using `descript-audio-codec` at 24KHz + 6kbps will NOT converge. It is unknown whether 44KHz fares any better.
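+
+For example, the backend can be picked programmatically before any audio is processed, as `scripts/process_dataset.py` does (a minimal sketch; the import path assumes the repo's package layout, and `vocos` is shown purely as an illustration):
+
+```py
+from vall_e.config import cfg # the repo's global config object (import path assumed)
+
+# select the codec before any encode/decode call so the right model gets loaded
+cfg.inference.audio_backend = "vocos" # or "encodec" / "dac"
+```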
+
+`llama`-based models also support different attention backends:
+* `math`: torch's SDPA's `math` implementation
+* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) implementation
+* `flash`: torch's SDPA's flash attention implementation
+* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
+* `auto`: automatically determines the best fit from the above
+* `sdpa`: integrated `LlamaSdpaAttention` attention module
+* `flash_attention_2`: integrated `LlamaFlashAttention2` attention module
 
 ## Export
 
diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py
index 65036e4..1290160 100644
--- a/scripts/process_dataset.py
+++ b/scripts/process_dataset.py
@@ -15,7 +15,7 @@ cfg.inference.audio_backend = "encodec"
 
 input_audio = "voices"
 input_metadata = "./training/metadata"
-output_dataset = f"./training/data-{'2' if cfg.sample_rate else '4'}4KHz-{cfg.inference.audio_backend}"
+output_dataset = f"./training/data-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
 
 device = "cuda"
 audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
diff --git a/vall_e/config.py b/vall_e/config.py
index 4bbc306..73fb4ae 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -558,7 +558,7 @@ class Inference:
 	amp: bool = False
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-	audio_backend: str = "dac"
+	audio_backend: str = "vocos" # encodec, vocos, dac
 
 	# legacy / backwards compat
 	use_vocos: bool = True
@@ -731,6 +731,7 @@ try:
 	if cfg.dataset.use_hdf5:
 		cfg.load_hdf5()
 except Exception as e:
+	cfg.dataset.use_hdf5 = False
 	print("Error while parsing config YAML:", e)
 	pass
 
diff --git a/vall_e/data.py b/vall_e/data.py
index 80585db..adf8574 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -66,7 +66,7 @@ def _get_quant_extension():
 	return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
 
 def _get_phone_extension():
-	return ".json" if cfg.inference.audio_backend == "dac" else ".phn.txt"
+	return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
 
 def _get_quant_path(path):
 	return _replace_file_extension(path, _get_quant_extension())
@@ -136,7 +136,7 @@ def _get_hdf5_path(path):
 def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	data_dir = str(data_dir)
 
-	def _validate(child):
+	def _validate( child ):
 		phones = child.attrs['phonemes']
 		duration = child.attrs['duration']
 		if type not in _total_durations:
@@ -145,7 +145,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	key = f"/{type}/{_get_hdf5_path(data_dir)}"
-	return [ Path(f"{key}/{child}") for child in cfg.hdf5[key].keys() if not validate or _validate(child) ] if key in cfg.hdf5 else []
+	return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
 
 def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
 	if isinstance(path, str):
@@ -906,7 +906,6 @@ def create_dataset_hdf5( skip_existing=True ):
 
 			if not audio_exists or not text_exists:
 				continue
-
 			key = f'{type}/{speaker_name}/{id}'
 
 			if skip_existing and key in hf:
@@ -1014,70 +1013,55 @@ def extract_dataset_hdf5( skip_existing=True ):
 
 	root = str(cfg.data_dir)
 
-	def add( dir, type="training", audios=True, texts=True ):
-		name = str(dir)
-		name = name.replace(root, "data/")
+	def add( type="training", audios=True, texts=True ):
+		for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
+			for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
+				(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
 
-		Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
+				for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
+					try:
+						key = f'{type}/data/{group}/{name}/{id}'
 
-		if f'{type}/{name}' not in hf:
-			return
+						if key not in hf:
+							tqdm.write(f'Missing key: {key}')
+							continue
 
-		ids = [ key for key in hf[f'{type}/{name}'].keys() ]
+						audio_exists = "audio" in hf[key]
+						text_exists = "text" in hf[key]
 
-		for id in tqdm(ids, desc=f"Processing {name}"):
-			try:
-				key = f'{type}/{name}/{id}'
+						if not audio_exists or not text_exists:
+							tqdm.write(f'Missing audio/text: {key}')
+							continue
 
-				if key not in hf:
-					tqdm.write(f'Missing key: {key}')
-					continue
+						audio_path = Path(f'{root}/{group}/{name}/{id}.enc')
+						text_path = Path(f'{root}/{group}/{name}/{id}.json')
 
-				group = hf[key]
-				audio_exists = "audio" in group
-				text_exists = "text" in group
+						# audio
+						if audios and audio_exists and not audio_path.exists():
+							qnt = hf[key]["audio"][:, :]
+							torch.save( qnt, audio_path )
 
-				if not audio_exists or not text_exists:
-					tqdm.write(f'Missing audio/text: {key}')
-					continue
+						# text
+						if texts and text_exists and not text_path.exists():
+							tokens = hf[key]["text"][:][1:-1]
+							phones = [ reverse_symmap[f'{token}'] for token in tokens ]
+							phones = list("".join(phones).replace(" ", " "))
 
-				audio_path = Path(f'{cfg.relpath}/{name}/{id}.enc')
-				text_path = Path(f'{cfg.relpath}/{name}/{id}.json')
+							j = {
+								"text": "",
+								"phonemes": phones,
+								"language": "en"
+							}
 
-				# audio
-				if audios and audio_exists and not audio_path.exists():
-					qnt = group["audio"][:, :]
-					torch.save( qnt, f'{cfg.relpath}/{name}/{id}.enc' )
+							with open(text_path, "w", encoding="utf-8") as f:
+								f.write( json.dumps( j ) )
 
-				# text
-				if texts and text_exists and not text_path.exists():
-					tokens = group["text"][:][1:-1]
-					phones = [ reverse_symmap[f'{token}'] for token in tokens ]
-					phones = list("".join(phones).replace(" ", " "))
+					except Exception as e:
+						raise e
 
-				j = {
-					"text": "",
-					"phonemes": phones,
-					"language": "en"
-				}
-
-				with open(text_path, "w", encoding="utf-8") as f:
-					f.write( json.dumps( j ) )
-
-			except Exception as e:
-				raise e
-
-	# training
-	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
-		add( data_dir, type="training" )
-
-	# validation
-	for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
-		add( data_dir, type="validation" )
-
-	# noise
-	for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
-		add( data_dir, type="noise", texts=False )
+	add( type="training" )
+	add( type="validation" )
+	add( type="noise", texts=False )
 
 	hf.close()
 
@@ -1091,49 +1075,38 @@ def retokenize_dataset_hdf5( skip_existing=True ):
 
 	root = str(cfg.data_dir)
 
-	def add( dir, type="training" ):
-		name = str(dir)
-		name = name.replace(root, "data/")
+	def add( type="training" ):
+		for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
+			for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
+				(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
 
-		Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
+				for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
+					try:
+						key = f'{type}/data/{group}/{name}/{id}'
 
-		if f'{type}/{name}' not in hf:
-			return
+						if key not in hf:
+							tqdm.write(f'Missing key: {key}')
+							continue
 
-		ids = [ key for key in hf[f'{type}/{name}'].keys() ]
+						if "text" not in hf[key]:
+							tqdm.write(f'Missing text: {key}')
+							continue
 
-		for id in tqdm(ids, desc=f"Processing {name}"):
-			try:
-				key = f'{type}/{name}/{id}'
+						# text
+						tokens = hf[key]["text"][:][1:-1]
+						content = list("".join([ reverse_symmap[f'{token}'] for token in tokens ]).replace(" ", " "))
 
-				if key not in hf:
-					tqdm.write(f'Missing key: {key}')
-					continue
+						tokens = cfg.tokenizer.encode("".join(content))
+						tokens = np.array(tokens).astype(np.uint8)
 
-				group = hf[key]
-				if not "text" in group:
-					tqdm.write(f'Missing text: {key}')
-					continue
+						del hf[key]['text']
+						hf[key].create_dataset('text', data=tokens, compression='lzf')
 
-				tokens = group["text"][:][1:-1]
-				content = list("".join([ reverse_symmap[f'{token}'] for token in tokens ]).replace(" ", " "))
+					except Exception as e:
+						raise e
 
-				tokens = cfg.tokenizer.encode("".join(content))
-				tokens = np.array(tokens).astype(np.uint8)
-
-				del group['text']
-				group.create_dataset('text', data=tokens, compression='lzf')
-
-			except Exception as e:
-				raise e
-
-	# training
-	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
-		add( data_dir, type="training" )
-
-	# validation
-	for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
-		add( data_dir, type="validation" )
+	add( type="training" )
+	add( type="validation" )
 
 	# write symmap
 	if "symmap" in hf:
@@ -1166,6 +1139,15 @@ if __name__ == "__main__":
 		extract_dataset_hdf5()
 	if args.action == "retokenize-hdf5":
 		retokenize_dataset_hdf5()
+	elif args.action == "list-dataset":
+		dataset = []
+		for group in os.listdir(cfg.data_dir):
+			for name in os.listdir(cfg.data_dir / group):
+				if len(os.listdir(cfg.data_dir / group / name)) == 0:
+					continue
+				dataset.append(f'{group}/{name}')
+
+		print(dataset)
 	elif args.action == "metadata":
 		create_dataset_metadata()
 	elif args.action == "sample":
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index b222f89..caa53b6 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -147,7 +147,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
 	if not cfg.variable_sample_rate:
 		# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
 		if cfg.sample_rate == 44_000:
-			kwargs["model_type"] = "44kz"
+			kwargs["model_type"] = "44khz"
 		elif cfg.sample_rate == 24_000:
 			kwargs["model_type"] = "24khz"
 		elif cfg.sample_rate == 16_000: