From caad7ee3c9f02f4e036ba9df111ddce5b7806d86 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 28 Apr 2024 22:28:29 -0500
Subject: [PATCH] final tweaks, hopefully

---
 scripts/cleanup_dataset.py    |  99 ++++++++++++++++++++++++++++
 scripts/process_dataset.py    |  27 ++++----
 scripts/transcribe_dataset.py |  44 ++++++++++---
 vall_e/config.py              |   8 +++
 vall_e/data.py                | 117 ++++++++++++++++++------------
 vall_e/models/ar_nar.py       |   6 +-
 6 files changed, 220 insertions(+), 81 deletions(-)
 create mode 100644 scripts/cleanup_dataset.py

diff --git a/scripts/cleanup_dataset.py b/scripts/cleanup_dataset.py
new file mode 100644
index 0000000..083de32
--- /dev/null
+++ b/scripts/cleanup_dataset.py
@@ -0,0 +1,99 @@
+import os
+import json
+import torch
+import torchaudio
+
+from tqdm.auto import tqdm
+from pathlib import Path
+
+input_dataset = "metadata"
+output_dataset = "metadata-cleaned"
+
+def pad(num, zeroes):
+	return str(num).zfill(zeroes+1)
+
+for dataset_name in os.listdir(f'./{input_dataset}/'):
+	if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
+		print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
+		continue
+
+	for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
+		if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
+			print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
+			continue
+
+		inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
+		outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
+
+		if not inpath.exists():
+			continue
+
+		if outpath.exists():
+			continue
+
+		os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
+
+		try:
+			in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
+		except Exception as e:
+			print("Failed to open metadata file:", inpath)
+			continue
+
+		out_metadata = {}
+		speaker_metadatas = {}
+
+		for filename, result in in_metadata.items():
+			language = result["language"] if "language" in result else "en"
+			out_metadata[filename] = {
+				"segments": [],
+				"language": language,
+				"text": "",
+				"start": 0,
+				"end": 0,
+			}
+			segments = []
+			text = []
+			start = 0
+			end = 0
+			diarized = False
+
+			for segment in result["segments"]:
+				# diarize split
+				if "speaker" in segment:
+					diarized = True
+					speaker_id = segment["speaker"]
+					if speaker_id not in speaker_metadatas:
+						speaker_metadatas[speaker_id] = {}
+
+					if filename not in speaker_metadatas[speaker_id]:
+						speaker_metadatas[speaker_id][filename] = {
+							"segments": [],
+							"language": language,
+							"text": "",
+							"start": 0,
+							"end": 0,
+						}
+
+					speaker_metadatas[speaker_id][filename]["segments"].append( segment )
+				else:
+					segments.append( segment )
+
+				text.append( segment["text"] )
+				start = min( start, segment["start"] )
+				end = max( end, segment["end"] )
+
+			out_metadata[filename]["segments"] = segments
+			out_metadata[filename]["text"] = " ".join(text).strip()
+			out_metadata[filename]["start"] = start
+			out_metadata[filename]["end"] = end
+
+			if len(segments) == 0:
+				del out_metadata[filename]
+
+		open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
+
+		for speaker_id, out_metadata in speaker_metadatas.items():
+			os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
+			outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
+
+			open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
\ No newline at end of file
diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py
index c7a8a94..0c628a3 100644
--- a/scripts/process_dataset.py
+++ b/scripts/process_dataset.py
@@ -8,26 +8,27 @@ from pathlib import Path
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
 
-input_audio = "voice"
+# things that could be args
+input_audio = "voices"
 input_metadata = "metadata"
 output_dataset = "training-24K"
+device = "cuda"
+slice = "auto"
 
 missing = {
 	"transcription": [],
 	"audio": []
 }
-device = "cuda"
-
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
-for dataset_name in os.listdir(f'./{input_audio}/'):
+for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 	if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 		print("Is not dir:", f'./{input_audio}/{dataset_name}/')
 		continue
 
-	for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
+	for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
 			print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
 			continue
 
@@ -36,24 +37,23 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 		metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
 
 		if not metadata_path.exists():
-			#print("Does not exist:", metadata_path)
 			missing["transcription"].append(str(metadata_path))
 			continue
 
 		try:
 			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
 		except Exception as e:
-			#print("Failed to load metadata:", metadata_path, e)
 			missing["transcription"].append(str(metadata_path))
 			continue
 
 		txts = []
 		wavs = []
 
-		for filename in metadata.keys():
+		use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
+
+		for filename in sorted(metadata.keys()):
 			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
 			if not inpath.exists():
-				#print("Does not exist:", inpath)
 				missing["audio"].append(str(inpath))
 				continue
 
@@ -63,9 +63,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 			waveform, sample_rate = None, None
 			language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
 
-			if len(metadata[filename]["segments"]) == 0:
-				id = pad(0, 4)
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
+			if len(metadata[filename]["segments"]) == 0 or not use_slices:
+				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
 
 				text = metadata[filename]["text"]
 				if len(text) == 0:
@@ -91,8 +90,10 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 					sample_rate
 				))
 			else:
+				i = 0
 				for segment in metadata[filename]["segments"]:
-					id = pad(segment['id'], 4)
+					id = pad(i, 4)
+					i = i + 1
 					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
 
 					if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
diff --git a/scripts/transcribe_dataset.py b/scripts/transcribe_dataset.py
index ca8dcfa..3814a3c 100644
--- a/scripts/transcribe_dataset.py
+++ b/scripts/transcribe_dataset.py
@@ -7,30 +7,35 @@ import whisperx
 from tqdm.auto import tqdm
 from pathlib import Path
 
-device = "cuda"
+# should be args
 batch_size = 16
+device = "cuda"
 dtype = "float16"
-model_size = "large-v2"
+model_name = "large-v3"
 
-input_audio = "voice"
+input_audio = "voices"
"metadata" + skip_existing = True +diarize = False -model = whisperx.load_model(model_size, device, compute_type=dtype) - +# +model = whisperx.load_model(model_name, device, compute_type=dtype) align_model, align_model_metadata, align_model_language = (None, None, None) +if diarize: + diarize_model = whisperx.DiarizationPipeline(device=device) +else: + diarize_model = None def pad(num, zeroes): return str(num).zfill(zeroes+1) for dataset_name in os.listdir(f'./{input_audio}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): - print("Is not dir:", f'./{input_audio}/{dataset_name}/') continue for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): - print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}') continue outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json') @@ -46,18 +51,29 @@ for dataset_name in os.listdir(f'./{input_audio}/'): if skip_existing and filename in metadata: continue + if ".json" in filename: + continue + inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}' + + if os.path.isdir(inpath): + continue metadata[filename] = { "segments": [], "language": "", - "text": [], + "text": "", + "start": 0, + "end": 0, } audio = whisperx.load_audio(inpath) result = model.transcribe(audio, batch_size=batch_size) language = result["language"] + if language[:2] not in ["ja"]: + language = "en" + if align_model_language != language: tqdm.write(f'Loading language: {language}') align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device) @@ -68,12 +84,20 @@ for dataset_name in os.listdir(f'./{input_audio}/'): metadata[filename]["segments"] = result["segments"] metadata[filename]["language"] = language + if diarize_model is not None: + diarize_segments = diarize_model(audio) + result = whisperx.assign_word_speakers(diarize_segments, result) + text = [] + start = 0 + end = 0 for segment in result["segments"]: - id = len(text) text.append( segment["text"] ) - metadata[filename]["segments"][id]["id"] = id + start = min( start, segment["start"] ) + end = max( end, segment["end"] ) metadata[filename]["text"] = " ".join(text).strip() + metadata[filename]["start"] = start + metadata[filename]["end"] = end open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata)) \ No newline at end of file diff --git a/vall_e/config.py b/vall_e/config.py index e5be02c..580fde6 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -33,6 +33,14 @@ class _Config: def cache_dir(self): return self.relpath / ".cache" + @property + def data_dir(self): + return self.relpath / "data" + + @property + def metadata_dir(self): + return self.relpath / "metadata" + @property def ckpt_dir(self): return self.relpath / "ckpt" diff --git a/vall_e/data.py b/vall_e/data.py index cf512be..85129a1 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -85,46 +85,45 @@ def _calculate_durations( type="training" ): @cfg.diskcache() def _load_paths(dataset, type="training"): - return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") } + return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") } + +def 
_load_paths_from_metadata(dataset_name, type="training", validate=False): + data_dir = dataset_name if cfg.dataset.use_hdf5 else cfg.data_dir / dataset_name -def _load_paths_from_metadata(data_dir, type="training", validate=False): _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions def _validate( entry ): - if "phones" not in entry or "duration" not in entry: - return False - phones = entry['phones'] - duration = entry['duration'] + phones = entry['phones'] if "phones" in entry else 0 + duration = entry['duration'] if "duration" in entry else 0 if type not in _total_durations: _total_durations[type] = 0 - _total_durations[type] += entry['duration'] + _total_durations[type] += duration return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones - metadata_path = data_dir / "metadata.json" + metadata_path = cfg.metadata_dir / f'{dataset_name}.json' metadata = {} + if cfg.dataset.use_metadata and metadata_path.exists(): metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if len(metadata) == 0: return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) + def key( dir, id ): if not cfg.dataset.use_hdf5: return data_dir / id - return f"/{type}{_get_hdf5_path(data_dir)}/{id}" + return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ] def _get_hdf5_path(path): - path = str(path) - if path[:2] != "./": - path = f'./{path}' - - res = path.replace(cfg.cfg_path, "") - return res + # to-do: better validation + #print(path) + return str(path) def _get_hdf5_paths( data_dir, type="training", validate=False ): data_dir = str(data_dir) @@ -137,7 +136,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ): _total_durations[type] += child.attrs['duration'] return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones - key = f"/{type}{_get_hdf5_path(data_dir)}" + key = f"/{type}/{_get_hdf5_path(data_dir)}" return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else [] def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ): @@ -427,6 +426,9 @@ class Dataset(_Dataset): if cfg.dataset.use_hdf5: key = _get_hdf5_path(path) + if key not in cfg.hdf5: + raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}') + text = cfg.hdf5[key]["text"][:] resps = cfg.hdf5[key]["audio"][:, :] @@ -752,6 +754,10 @@ def create_train_val_dataloader(): # parse dataset into better to sample metadata def create_dataset_metadata(): + # need to fix + if True: + return + cfg.dataset.validate = False cfg.dataset.use_hdf5 = False @@ -805,14 +811,19 @@ def create_dataset_hdf5( skip_existing=True ): symmap = get_phone_symmap() - root = cfg.cfg_path + root = str(cfg.data_dir) + metadata_root = str(cfg.metadata_dir) hf = cfg.hdf5 + cfg.metadata_dir.mkdir(parents=True, exist_ok=True) def add( dir, type="training", audios=True, texts=True ): - name = "./" + str(dir) - name = name .replace(root, "") - metadata = {} + name = str(dir) + name = name.replace(root, "") + + metadata_path = Path(f"{metadata_root}/{name}.json") + + metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) if not 
 
 		if not os.path.isdir(f'{root}/{name}/'):
 			return
 
@@ -831,36 +842,38 @@ def create_dataset_hdf5( skip_existing=True ):
 					continue
 
 				key = f'{type}/{name}/{id}'
-				if key in hf:
-					if skip_existing:
-						continue
-					del hf[key]
-				group = hf.create_group(key)
+				if skip_existing and key in hf:
+					continue
+
+				group = hf.create_group(key) if key not in hf else hf[key]
+
 				group.attrs['id'] = id
 				group.attrs['type'] = type
 				group.attrs['speaker'] = name
 
-				metadata[id] = {}
+				if id not in metadata:
+					metadata[id] = {}
 
 				# audio
 				if audios:
-					qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
-					codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
-
 					if _get_quant_extension() == ".dac":
-						if "audio" in group:
-							del group["audio"]
-						duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
+						dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+						qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
+
+						duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
 						metadata[id]["metadata"] = {
-							"original_length": qnt["metadata"]["original_length"],
-							"sample_rate": qnt["metadata"]["sample_rate"],
+							"original_length": dac["metadata"]["original_length"],
+							"sample_rate": dac["metadata"]["sample_rate"],
 						}
 					else:
 						qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
 						duration = qnt.shape[0] / 75
 
-					group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
+					qnt = qnt.numpy().astype(np.int16)
+
+					if "audio" not in group:
+						group.create_dataset('audio', data=qnt, compression='lzf')
 
 					group.attrs['duration'] = duration
 					metadata[id]["duration"] = duration
 
@@ -870,52 +883,46 @@ def create_dataset_hdf5( skip_existing=True ):
 
 				# text
 				if texts:
-					if _get_quant_extension() == ".json":
+					if _get_phone_extension() == ".json":
 						json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
 						content = json_metadata["phonemes"]
+						txt = json_metadata["text"]
 					else:
 						content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
-
-					"""
-					phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
-					for s in set(phones):
-						if s not in symmap:
-							symmap[s] = len(symmap.keys())
-
-					phn = [ symmap[s] for s in phones ]
-					"""
+						txt = ""
 
 					phn = cfg.tokenizer.encode("".join(content))
 					phn = np.array(phn).astype(np.uint8)
 
-					if "text" in group:
-						del group["text"]
-
-					group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-					group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
+					if "text" not in group:
+						group.create_dataset('text', data=phn, compression='lzf')
 
 					group.attrs['phonemes'] = len(phn)
+					group.attrs['transcription'] = txt
+
 					metadata[id]["phones"] = len(phn)
+					metadata[id]["transcription"] = txt
 
 				else:
 					group.attrs['phonemes'] = 0
 					metadata[id]["phones"] = 0
 
 			except Exception as e:
-				pass
+				raise e
+				#pass
 
-		with open(dir / "metadata.json", "w", encoding="utf-8") as f:
+		with open(str(metadata_path), "w", encoding="utf-8") as f:
 			f.write( json.dumps( metadata ) )
 
 	# training
-	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
+	for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
 		add( data_dir, type="training" )
 
 	# validation
-	for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
+	for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
 		add( data_dir, type="validation" )
 
 	# noise
-	for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
+	for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
 		add( data_dir, type="noise", texts=False )
 
 	# write symmap
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 49d0e93..928d873 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -340,7 +340,7 @@ def example_usage():
 	def _load_quants(path) -> Tensor:
 		if cfg.inference.audio_backend == "dac":
 			qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
-			return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
+			return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
 		return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
 
 	qnt = _load_quants("./data/qnt")
@@ -350,7 +350,7 @@ def example_usage():
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
 	]
 	proms_list = [
-		qnt[:75*3, :].to(device),
+		qnt.to(device),
 	]
 	resps_list = [
 		qnt.to(device),
@@ -407,7 +407,7 @@ def example_usage():
 
 	frozen_params = set()
 	for k in list(embeddings.keys()):
-		if re.findall(r'_emb\.', k):
+		if re.findall(r'_emb.', k):
 			frozen_params.add(k)
 		else:
 			del embeddings[k]