Oops, turns out these are not already split by speaker name... (also added sampling the dataset in the web UI for easy viewing)

This commit is contained in:
mrq 2024-09-18 20:19:46 -05:00
parent fa9d3f6c06
commit b5bec0c9ce
3 changed files with 66 additions and 17 deletions

View File

@ -100,29 +100,29 @@ def process(
group_name = "Emilia" group_name = "Emilia"
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"): for speaker_group in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"):
if not os.path.isdir(f'./{input_audio}/{language}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{language}/{speaker_group}'):
print("Is not dir:", f'./{input_audio}/{language}/{speaker_id}') print("Is not dir:", f'./{input_audio}/{language}/{speaker_group}')
continue continue
if speaker_id in ignore_speakers: if speaker_group in ignore_speakers:
continue continue
if only_speakers and speaker_id not in only_speakers: if only_speakers and speaker_group not in only_speakers:
continue continue
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) os.makedirs(f'./{output_dataset}/{group_name}/{speaker_group}/', exist_ok=True)
if f'{group_name}/{speaker_id}' not in dataset: if f'{group_name}/{speaker_group}' not in dataset:
dataset.append(f'{group_name}/{speaker_id}') dataset.append(f'{group_name}/{speaker_group}')
txts = [] txts = []
wavs = [] wavs = []
for filename in os.listdir(f'./{input_audio}/{language}/{speaker_id}'): for filename in os.listdir(f'./{input_audio}/{language}/{speaker_group}'):
if ".mp3" not in filename: if ".mp3" not in filename:
continue continue
inpath = Path(f'./{input_audio}/{language}/{speaker_id}/{filename}') inpath = Path(f'./{input_audio}/{language}/{speaker_group}/{filename}')
jsonpath = _replace_file_extension(inpath, ".json") jsonpath = _replace_file_extension(inpath, ".json")
if not inpath.exists() or not jsonpath.exists(): if not inpath.exists() or not jsonpath.exists():
missing["audio"].append(str(inpath)) missing["audio"].append(str(inpath))
@ -130,15 +130,15 @@ def process(
extension = os.path.splitext(filename)[-1][1:] extension = os.path.splitext(filename)[-1][1:]
fname = filename.replace(f'.{extension}', "") fname = filename.replace(f'.{extension}', "")
waveform, sample_rate = None, None
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
metadata = json.load(open(jsonpath, "r", encoding="utf-8"))
if "text" not in metadata: if "text" not in metadata:
continue continue
waveform, sample_rate = None, None
metadata = json.load(open(jsonpath, "r", encoding="utf-8"))
speaker_id = metadata["speaker"]
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
if _replace_file_extension(outpath, audio_extension).exists(): if _replace_file_extension(outpath, audio_extension).exists():
continue continue

View File

@ -13,9 +13,13 @@ except:
from .utils import truncate_json from .utils import truncate_json
def json_stringify( data, truncate=False, pretty=False ):
    """Serialize `data` to JSON using whichever backend is bound to `json`.

    Args:
        data: Object to serialize.
        truncate: When true, pass the serialized output through
            `truncate_json`. Takes precedence over `pretty`.
        pretty: When true, return an indented `str`.

    Returns:
        The result of `truncate_json` when `truncate` is set, an indented
        `str` when `pretty` is set, otherwise whatever `json.dumps` returns.
    """
    if truncate:
        return truncate_json( json.dumps( data ) )
    if pretty:
        if use_orjson:
            # orjson's dumps returns bytes, hence the decode to str.
            return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8')
        # The standard-library json.dumps already returns str; calling
        # .decode() on it raises AttributeError, so return it directly.
        return json.dumps( data, indent='\t' )
    return json.dumps( data )
def json_parse( string ): def json_parse( string ):

View File

@ -5,6 +5,10 @@ import argparse
import random import random
import tempfile import tempfile
import functools import functools
import torch
import numpy as np
from datetime import datetime from datetime import datetime
import torchaudio import torchaudio
@ -16,6 +20,8 @@ from pathlib import Path
from .inference import TTS, cfg from .inference import TTS, cfg
from .train import train from .train import train
from .utils import get_devices, setup_logging from .utils import get_devices, setup_logging
from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
tts = None tts = None
@ -23,6 +29,7 @@ layout = {}
layout["inference_tts"] = {} layout["inference_tts"] = {}
layout["inference_stt"] = {} layout["inference_stt"] = {}
layout["training"] = {} layout["training"] = {}
layout["dataset"] = {}
layout["settings"] = {} layout["settings"] = {}
for k in layout.keys(): for k in layout.keys():
@ -90,6 +97,29 @@ def load_model( yaml, device, dtype, attention ):
raise gr.Error(e) raise gr.Error(e)
gr.Info(f"Loaded model") gr.Info(f"Loaded model")
def get_speakers():
    """Return the `training` entry of the global config's dataset section."""
    dataset_cfg = cfg.dataset
    return dataset_cfg.training
#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
def load_sample( speaker ):
    """Pick a random entry of `speaker` and return its metadata and audio.

    Args:
        speaker: Dataset entry name. It is used both as `<speaker>.json`
            under `cfg.metadata_dir` and as a sub-directory of `cfg.data_dir`.

    Returns:
        A `(metadata_text, audio)` pair. `metadata_text` is the pretty-printed
        JSON of the randomly chosen metadata entry. `audio` is
        `(sample_rate, waveform)` when the encoded file exists, else `None`.

    Raises:
        gr.Error: If no speaker was given or its metadata is missing or empty.
    """
    if not speaker:
        # Without a selection the lookups below would target "None.json".
        raise gr.Error("No speaker selected")

    metadata_path = cfg.metadata_dir / f'{speaker}.json'
    metadata = json_read( metadata_path )
    if not metadata:
        raise gr.Error(f"Metadata not found: {metadata_path}")

    key = random.choice( list(metadata) )
    path = cfg.data_dir / speaker / f'{key}.enc' # to-do: get proper file extension
    data = json_stringify( metadata[key], pretty=True )

    if not path.exists():
        # Return None for the audio slot: a (None, None) tuple is not a valid
        # audio value, while None simply leaves the component empty.
        return data, None

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only point
    # cfg.data_dir at dataset files you trust.
    artifact = np.load(path, allow_pickle=True)[()]
    codes = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=cfg.device)
    wav, sr = decode_to_wave( codes )
    wav = wav.squeeze(0).cpu().numpy()

    return data, (sr, wav)
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"): def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
global tts global tts
@ -429,6 +459,21 @@ with ui:
) )
""" """
# "Dataset" tab: choose a speaker, press "Sample", and load_sample fills the
# metadata textbox and the audio player.
with gr.Tab("Dataset"):
with gr.Row():
with gr.Column(scale=7):
layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
with gr.Column(scale=1):
# get_speakers() is called once, while the dropdown is being constructed.
layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
# The outputs list follows dict insertion order: "transcription" is registered
# before "audio", matching the (metadata, audio) order load_sample returns.
layout["dataset"]["buttons"]["sample"].click(
fn=load_sample,
inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
)
with gr.Tab("Settings"): with gr.Tab("Settings"):
with gr.Row(): with gr.Row():
with gr.Column(scale=7): with gr.Column(scale=7):