From c6e0f905b535c91b8ec705f43eccf096cb08904c Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 8 May 2024 02:11:38 -0500 Subject: [PATCH] final tweaks (again) before training restarts --- scripts/process_dataset.py | 3 ++- vall_e/data.py | 41 ++++++++++++++++++++++++++++---------- vall_e/emb/qnt.py | 10 ++++++++-- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py index 1ae4f50..52b2908 100644 --- a/scripts/process_dataset.py +++ b/scripts/process_dataset.py @@ -64,7 +64,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): waveform, sample_rate = None, None language = metadata[filename]["language"] if "language" in metadata[filename] else "english" - dataset.append(f'{dataset_name}/{speaker_id}') + if f'{dataset_name}/{speaker_id}' not in dataset: + dataset.append(f'{dataset_name}/{speaker_id}') if len(metadata[filename]["segments"]) == 0 or not use_slices: outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}') diff --git a/vall_e/data.py b/vall_e/data.py index aa5f2a4..f6e02c9 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -763,7 +763,13 @@ def create_dataset_metadata( skip_existing=True ): name = str(dir) name = name.replace(root, "") - metadata_path = Path(f"{metadata_root}/{name}.json") + # yucky + speaker_name = name + if "LbriTTS-R" in speaker_name: + speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox") + + metadata_path = Path(f"{metadata_root}/{speaker_name}.json") + metadata_path.parents[0].mkdir(parents=True, exist_ok=True) metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) @@ -783,9 +789,9 @@ def create_dataset_metadata( skip_existing=True ): if not audio_exists or not text_exists: continue - key = f'{type}/{name}/{id}' + key = f'{type}/{speaker_name}/{id}' - if skip_existing and key in metadata: + if skip_existing and id in metadata: continue if id not in metadata: @@ -816,15 +822,18 @@ def create_dataset_metadata( skip_existing=True ): json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) content = json_metadata["phonemes"] txt = json_metadata["text"] + lang = json_metadata["language"][:2] else: content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") txt = "" + lang = "en" phn = cfg.tokenizer.encode("".join(content)) phn = np.array(phn).astype(np.uint8) metadata[id]["phones"] = len(phn) metadata[id]["transcription"] = txt + metadata[id]["language"] = lang except Exception as e: #raise e print(id, e) @@ -849,20 +858,25 @@ def create_dataset_metadata( skip_existing=True ): def create_dataset_hdf5( skip_existing=True ): cfg.dataset.use_hdf5 = True cfg.load_hdf5(write=True) + hf = cfg.hdf5 symmap = get_phone_symmap() root = str(cfg.data_dir) metadata_root = str(cfg.metadata_dir) - hf = cfg.hdf5 - cfg.metadata_dir.mkdir(parents=True, exist_ok=True) def add( dir, type="training", audios=True, texts=True ): name = str(dir) name = name.replace(root, "") + + # yucky + speaker_name = name + if "LbriTTS-R" in speaker_name: + speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox") - metadata_path = Path(f"{metadata_root}/{name}.json") + metadata_path = Path(f"{metadata_root}/{speaker_name}.json") + metadata_path.parents[0].mkdir(parents=True, exist_ok=True) metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) @@ -882,7 +896,8 @@ def create_dataset_hdf5( skip_existing=True ): if not audio_exists or not text_exists: continue - key = f'{type}/{name}/{id}' + + key = f'{type}/{speaker_name}/{id}' """ if skip_existing and key in hf: @@ -893,7 +908,7 @@ def create_dataset_hdf5( skip_existing=True ): group.attrs['id'] = id group.attrs['type'] = type - group.attrs['speaker'] = name + group.attrs['speaker'] = speaker_name if id not in metadata: metadata[id] = {} @@ -930,9 +945,11 @@ def create_dataset_hdf5( skip_existing=True ): json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) content = json_metadata["phonemes"] txt = json_metadata["text"] + lang = json_metadata["language"][:2] else: content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") txt = "" + lang = "en" phn = cfg.tokenizer.encode("".join(content)) phn = np.array(phn).astype(np.uint8) @@ -942,9 +959,11 @@ def create_dataset_hdf5( skip_existing=True ): group.attrs['phonemes'] = len(phn) group.attrs['transcription'] = txt + group.attrs['language'] = lang metadata[id]["phones"] = len(phn) metadata[id]["transcription"] = txt + metadata[id]["language"] = lang else: group.attrs['phonemes'] = 0 metadata[id]["phones"] = 0 @@ -958,15 +977,15 @@ def create_dataset_hdf5( skip_existing=True ): # training - for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): + for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"): add( data_dir, type="training" ) # validation - for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'): + for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'): add( data_dir, type="validation" ) # noise - for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'): + for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'): add( data_dir, type="noise", texts=False ) # write symmap diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 5b43f6c..8229a51 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -272,11 +272,17 @@ def _replace_file_extension(path, suffix): @torch.inference_mode() def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True): if cfg.inference.audio_backend == "dac": - model = _load_dac_model(device, levels=levels) + model = _load_dac_model(device, levels=levels ) signal = AudioSignal(wav, sample_rate=sr) - artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None) + + if not isinstance(levels, int): + levels = 8 if model.model_type == "24khz" else None + + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=False): # or True for about 2x speed, not enabling by default for systems that do not have bfloat16 + artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) # trim to 8 codebooks if 24Khz + # probably redundant with levels, should rewrite logic eventuall if model.model_type == "24khz": artifact.codes = artifact.codes[:, :8, :]