From 4f5c9e518a9499ec190d29e2e891d94369da2b96 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 18 Apr 2024 13:32:41 -0500
Subject: [PATCH] actually use the passed-through sample rate from encode for
 DAC because it does its own resampling I guess

---
 scripts/process_libritts.py    | 34 ++++++++++++------
 scripts/process_old_dataset.py | 62 +++++++++++++++++++++++++++++++++
 vall_e/config.py               |  2 +-
 vall_e/data.py                 | 31 ++++++++++-------
 vall_e/emb/g2p.py              |  3 ++
 vall_e/emb/qnt.py              |  4 +--
 vall_e/models/ar_nar.py        |  4 +--
 7 files changed, 111 insertions(+), 29 deletions(-)
 create mode 100644 scripts/process_old_dataset.py

diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py
index 8235f47..b4ce768 100755
--- a/scripts/process_libritts.py
+++ b/scripts/process_libritts.py
@@ -27,17 +27,29 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
 			if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
 				continue
 			for filename in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
-				os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')
+				# os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')
 
-				if ".original.txt" in filename:
-					txts.append(Path(f'./{output_dataset}/{speaker_id}/{filename}'))
-				if ".wav" in filename:
-					wavs.append(Path(f'./{output_dataset}/{speaker_id}/{filename}'))
+				inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
+				outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')
+
+				if ".original.txt" in filename and not _replace_file_extension(outpath, ".json").exists():
+					txts.append([inpath, outpath])
+				if ".wav" in filename and not _replace_file_extension(outpath, ".dac").exists():
+					wavs.append([inpath, outpath])
 
-for path in tqdm(txts, desc="Phonemizing..."):
-	phones = valle_phonemize(open(path, "r", encoding="utf-8").read())
-	open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
+for paths in tqdm(txts, desc="Phonemizing..."):
+	text = open(paths[0], "r", encoding="utf-8").read()
+	phones = valle_phonemize(text)
+	data = {
+		"text": text,
+		"phonemes": phones,
+		"language": "english",
+	}
+	open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
+	#phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
+	#open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
 
-for path in tqdm(wavs, desc="Quantizing..."):
-	qnt = valle_quantize(path, device=device)
-	torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
+for paths in tqdm(wavs, desc="Quantizing..."):
+	qnt = valle_quantize(paths[0], device=device)
+	qnt.save(_replace_file_extension(paths[1], ".dac"))
+	#torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
diff --git a/scripts/process_old_dataset.py b/scripts/process_old_dataset.py
new file mode 100644
index 0000000..80608a6
--- /dev/null
+++ b/scripts/process_old_dataset.py
@@ -0,0 +1,62 @@
+import os
+import json
+import torch
+
+from tqdm.auto import tqdm
+from pathlib import Path
+from vall_e.emb.g2p import encode as valle_phonemize
+from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
+
+input_audio = "voices"
+input_metadata = "metadata"
+output_dataset = "training"
+
+device = "cuda"
+
+txts = []
+wavs = []
+
+for dataset_name in os.listdir(f'./{input_audio}/'):
+	if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
+		continue
+
+	for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
+		if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
+			continue
+
+		os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
+		for filename in os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'):
+			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
+			outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{filename}')
+
+			metadata_json = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
+
+			if not metadata_json.exists() or not inpath.exists():
+				print("Does not exist:", metadata_json, inpath)
+				continue
+
+			if ".wav" not in filename and ".mp3" not in filename:
+				continue
+
+			if not _replace_file_extension(outpath, ".json").exists():
+				txts.append([ inpath, outpath ])
+
+			if not _replace_file_extension(outpath, ".dac").exists():
+				wavs.append([ inpath, outpath ])
+
+for paths in tqdm(txts, desc="Phonemizing..."):
+	text = open(paths[0], "r", encoding="utf-8").read()
+	phones = valle_phonemize(text)
+	data = {
+		"text": text,
+		"phonemes": phones,
+		"language": "english",
+	}
+	open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
+	#phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
+	#open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
+
+for paths in tqdm(wavs, desc="Quantizing..."):
+	qnt = valle_quantize(paths[0], device=device)
+	qnt.save(_replace_file_extension(paths[1], ".dac"))
+	#torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
diff --git a/vall_e/config.py b/vall_e/config.py
index 4605609..4f911ef 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -484,7 +484,7 @@ class Inference:
 	amp: bool = False
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-	audio_backend: str = "vocos"
+	audio_backend: str = "dac"
 
 	# legacy / backwards compat
 	use_vocos: bool = True
diff --git a/vall_e/data.py b/vall_e/data.py
index 8984a94..e922461 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -836,27 +836,32 @@ def create_dataset_hdf5( skip_existing=True ):
 			if "audio" in group:
 				del group["audio"]
 			group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-			group.attrs['duration'] = qnt.shape[0] / 75
-			metadata[id]["duration"] = qnt.shape[0] / 75
 		else:
 			group.attrs['duration'] = 0
 			metadata[id]["duration"] = 0
+			group.attrs['duration'] = qnt.shape[0] # / 75
+			metadata[id]["duration"] = qnt.shape[0] # / 75
 
 		# text
 		if texts:
-			content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ")
-			phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
-			for s in set(phones):
-				if s not in symmap:
-					symmap[s] = len(symmap.keys())
+			"""
+			content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ")
+			phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
+			for s in set(phones):
+				if s not in symmap:
+					symmap[s] = len(symmap.keys())
 
-			phn = [ symmap[s] for s in phones ]
+			phn = [ symmap[s] for s in phones ]
 
-			if "text" in group:
-				del group["text"]
-			group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-			group.attrs['phonemes'] = len(phn)
-			metadata[id]["phones"] = len(phn)
+			if "text" in group:
+				del group["text"]
+
+			group.create_dataset('text', data=phn, compression='lzf', chunks=True)
+			group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
+			"""
+
+			group.attrs['phonemes'] = len(phn)
+			metadata[id]["phones"] = len(phn)
 		else:
 			group.attrs['phonemes'] = 0
 			metadata[id]["phones"] = 0
diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py
index 3c64536..f1308b0 100755
--- a/vall_e/emb/g2p.py
+++ b/vall_e/emb/g2p.py
@@ -49,6 +49,8 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
 		tokens = phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True )
 
 	tokens = list(tokens[0])
+	return tokens
+	"""
 	tokenized = " ".join( tokens )
 
 	merges = [ "\u02C8", "\u02CC", "\u02D0" ]
@@ -56,6 +58,7 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
 		tokenized = tokenized.replace( f' {merge}', merge )
 
 	return tokenized.split(" ")
+	"""
 
 
 @torch.no_grad()
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index a2fd12e..40827ad 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -262,10 +262,10 @@ def _replace_file_extension(path, suffix):
 
 @torch.inference_mode()
-def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=False):
+def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
 	if cfg.inference.audio_backend == "dac":
 		model = _load_dac_model(device, levels=levels)
 
-		signal = AudioSignal(wav, sample_rate=model.sample_rate)
+		signal = AudioSignal(wav, sample_rate=sr)
 		artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
 
 		return artifact.codes if not return_metadata else artifact
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index c868dca..a6d995e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -384,7 +384,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 500
+	steps = 750
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
 	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
@@ -427,7 +427,7 @@ def example_usage():
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
-	def sample( name, steps=600 ):
+	def sample( name, steps=1000 ):
 		if cfg.inference.audio_backend == "dac" and name == "init":
 			return
 
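
Note on the headline change in vall_e/emb/qnt.py: DAC's compress() resamples its
input to the model's sample rate internally (hence the commit subject), so
constructing the AudioSignal with sample_rate=model.sample_rate mislabeled audio
recorded at any other rate and skipped the resample it actually needed; passing
the caller's sr through lets DAC do the conversion itself. A minimal sketch of
the before/after behavior, assuming the descript-audio-codec and audiotools
packages; the file path and model choice here are illustrative, not from the
patch (vall_e loads the model via its own _load_dac_model helper):

	import dac
	import torchaudio
	from audiotools import AudioSignal

	# Load a pretrained 44.1 kHz DAC model.
	model = dac.DAC.load(dac.utils.download(model_type="44khz")).to("cuda")

	wav, sr = torchaudio.load("utterance.wav")  # sr = the file's native rate, e.g. 22050

	# Before: tagging the tensor with the model's rate told compress() there
	# was nothing to resample, so 22 kHz samples were encoded as if they
	# were already 44.1 kHz audio.
	# signal = AudioSignal(wav, sample_rate=model.sample_rate)

	# After: pass the true source rate; compress() resamples to
	# model.sample_rate itself before encoding.
	signal = AudioSignal(wav, sample_rate=sr)
	artifact = model.compress(signal, 5.0, verbose=False)  # 5.0 s windows, as in qnt.py
	artifact.save("utterance.dac")  # DACFile on disk, matching the .dac outputs above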