From b5bec0c9cefa446bb30c652d7d0a2d1be6b0a6e4 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 18 Sep 2024 20:19:46 -0500 Subject: [PATCH] oops, turns out these are not split by speaker names already........ (also added sampling the dataset in the webui for easy viewing) --- scripts/process_emilia.py | 32 ++++++++++++++-------------- vall_e/utils/io.py | 6 +++++- vall_e/webui.py | 45 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 17 deletions(-) diff --git a/scripts/process_emilia.py b/scripts/process_emilia.py index a291959..7a9bbe7 100644 --- a/scripts/process_emilia.py +++ b/scripts/process_emilia.py @@ -100,29 +100,29 @@ def process( group_name = "Emilia" - for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"): - if not os.path.isdir(f'./{input_audio}/{language}/{speaker_id}'): - print("Is not dir:", f'./{input_audio}/{language}/{speaker_id}') + for speaker_group in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"): + if not os.path.isdir(f'./{input_audio}/{language}/{speaker_group}'): + print("Is not dir:", f'./{input_audio}/{language}/{speaker_group}') continue - if speaker_id in ignore_speakers: + if speaker_group in ignore_speakers: continue - if only_speakers and speaker_id not in only_speakers: + if only_speakers and speaker_group not in only_speakers: continue - os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) + os.makedirs(f'./{output_dataset}/{group_name}/{speaker_group}/', exist_ok=True) - if f'{group_name}/{speaker_id}' not in dataset: - dataset.append(f'{group_name}/{speaker_id}') + if f'{group_name}/{speaker_group}' not in dataset: + dataset.append(f'{group_name}/{speaker_group}') txts = [] wavs = [] - for filename in os.listdir(f'./{input_audio}/{language}/{speaker_id}'): + for 
filename in os.listdir(f'./{input_audio}/{language}/{speaker_group}'): if ".mp3" not in filename: continue - inpath = Path(f'./{input_audio}/{language}/{speaker_id}/{filename}') + inpath = Path(f'./{input_audio}/{language}/{speaker_group}/{filename}') jsonpath = _replace_file_extension(inpath, ".json") if not inpath.exists() or not jsonpath.exists(): missing["audio"].append(str(inpath)) @@ -130,15 +130,15 @@ def process( extension = os.path.splitext(filename)[-1][1:] fname = filename.replace(f'.{extension}', "") - - waveform, sample_rate = None, None - - outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') - metadata = json.load(open(jsonpath, "r", encoding="utf-8")) - if "text" not in metadata: continue + waveform, sample_rate = None, None + metadata = json.load(open(jsonpath, "r", encoding="utf-8")) + speaker_id = metadata["speaker"] + outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') + + if _replace_file_extension(outpath, audio_extension).exists(): continue diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index 8030ff5..b2d17c5 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -13,9 +13,13 @@ except: from .utils import truncate_json -def json_stringify( data, truncate=False ): +def json_stringify( data, truncate=False, pretty=False ): if truncate: return truncate_json( json.dumps( data ) ) + if pretty: + if use_orjson: + return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8') + return json.dumps( data, indent='\t' ).decode('utf-8') return json.dumps( data ) def json_parse( string ): diff --git a/vall_e/webui.py b/vall_e/webui.py index ac3a391..9ca3e05 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -5,6 +5,10 @@ import argparse import random import tempfile import functools + +import torch +import numpy as np + from datetime import datetime import torchaudio @@ -16,6 +20,8 @@ from pathlib import Path from .inference import TTS, cfg from .train import train from 
.utils import get_devices, setup_logging +from .utils.io import json_read, json_stringify +from .emb.qnt import decode_to_wave tts = None @@ -23,6 +29,7 @@ layout = {} layout["inference_tts"] = {} layout["inference_stt"] = {} layout["training"] = {} +layout["dataset"] = {} layout["settings"] = {} for k in layout.keys(): @@ -90,6 +97,29 @@ def load_model( yaml, device, dtype, attention ): raise gr.Error(e) gr.Info(f"Loaded model") +def get_speakers(): + return cfg.dataset.training + +#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys()) +def load_sample( speaker ): + metadata_path = cfg.metadata_dir / f'{speaker}.json' + metadata = json_read( metadata_path ) + if not metadata: + raise gr.Error(f"Metadata not found: {metadata_path}") + + key = random.choice( list(metadata.keys()) ) + path = cfg.data_dir / speaker / f'{key}.enc' # to-do: get proper file extension + data = json_stringify( metadata[key], pretty=True ) + wav, sr = None, None + + if path.exists(): + artifact = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=cfg.device) + wav, sr = decode_to_wave( codes ) + wav = wav.squeeze(0).cpu().numpy() + + return data, (sr, wav) + def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"): global tts @@ -429,6 +459,21 @@ with ui: ) """ + with gr.Tab("Dataset"): + with gr.Row(): + with gr.Column(scale=7): + layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata") + with gr.Column(scale=1): + layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers") + layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output") + layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample") + + layout["dataset"]["buttons"]["sample"].click( + fn=load_sample, + inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None], + outputs=[ x for x in 
layout["dataset"]["outputs"].values() if x is not None], + ) + with gr.Tab("Settings"): with gr.Row(): with gr.Column(scale=7):