oops, turns out these are not split by speaker names already... (also added sampling the dataset in the webui for easy viewing)

mrq 2024-09-18 20:19:46 -05:00
parent fa9d3f6c06
commit b5bec0c9ce
3 changed files with 66 additions and 17 deletions
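The gist of the fix: in the Emilia dump, the top-level folders under each language are bulk archive groups, not individual speakers; the real speaker id lives in each clip's sidecar JSON. A minimal sketch of the distinction (paths are hypothetical; the "speaker" key is the one the processing script reads in the diff below):

import json
from pathlib import Path

root = Path("./Emilia/EN")  # hypothetical input root
for group in root.iterdir():  # previously mistaken for a speaker directory
	if not group.is_dir():
		continue
	for jsonpath in group.glob("*.json"):
		metadata = json.loads(jsonpath.read_text(encoding="utf-8"))
		speaker_id = metadata["speaker"]  # the actual per-clip speaker id
		print(f"{group.name} -> {speaker_id}")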

View File

@@ -100,29 +100,29 @@ def process(
 	group_name = "Emilia"
-	for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"):
-		if not os.path.isdir(f'./{input_audio}/{language}/{speaker_id}'):
-			print("Is not dir:", f'./{input_audio}/{language}/{speaker_id}')
+	for speaker_group in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"):
+		if not os.path.isdir(f'./{input_audio}/{language}/{speaker_group}'):
+			print("Is not dir:", f'./{input_audio}/{language}/{speaker_group}')
 			continue
-		if speaker_id in ignore_speakers:
+		if speaker_group in ignore_speakers:
 			continue
-		if only_speakers and speaker_id not in only_speakers:
+		if only_speakers and speaker_group not in only_speakers:
 			continue
-		os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
+		os.makedirs(f'./{output_dataset}/{group_name}/{speaker_group}/', exist_ok=True)
-		if f'{group_name}/{speaker_id}' not in dataset:
-			dataset.append(f'{group_name}/{speaker_id}')
+		if f'{group_name}/{speaker_group}' not in dataset:
+			dataset.append(f'{group_name}/{speaker_group}')
 		txts = []
 		wavs = []
-		for filename in os.listdir(f'./{input_audio}/{language}/{speaker_id}'):
+		for filename in os.listdir(f'./{input_audio}/{language}/{speaker_group}'):
 			if ".mp3" not in filename:
 				continue
-			inpath = Path(f'./{input_audio}/{language}/{speaker_id}/{filename}')
+			inpath = Path(f'./{input_audio}/{language}/{speaker_group}/{filename}')
 			jsonpath = _replace_file_extension(inpath, ".json")
 			if not inpath.exists() or not jsonpath.exists():
 				missing["audio"].append(str(inpath))
@@ -130,15 +130,15 @@ def process(
 			extension = os.path.splitext(filename)[-1][1:]
 			fname = filename.replace(f'.{extension}', "")
-			waveform, sample_rate = None, None
-			outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
-
-			metadata = json.load(open(jsonpath, "r", encoding="utf-8"))
+			waveform, sample_rate = None, None
+			metadata = json.load(open(jsonpath, "r", encoding="utf-8"))
+			speaker_id = metadata["speaker"]
+			outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
 			if "text" not in metadata:
 				continue
 			if _replace_file_extension(outpath, audio_extension).exists():
 				continue
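One consequence of moving `outpath` after the metadata read: it now targets a directory named after the metadata speaker, while the `os.makedirs` call in the previous hunk still only creates the `speaker_group` folder. A hedged guard (not part of this commit) that would make the write destination safe:

# hypothetical addition, not in the diff: make sure the per-speaker output
# directory exists before anything is written to outpath
outpath.parent.mkdir(parents=True, exist_ok=True)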

View File

@@ -13,9 +13,13 @@ except:
 from .utils import truncate_json
-def json_stringify( data, truncate=False ):
+def json_stringify( data, truncate=False, pretty=False ):
 	if truncate:
 		return truncate_json( json.dumps( data ) )
+	if pretty:
+		if use_orjson:
+			return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8')
+		return json.dumps( data, indent='\t' ) # stdlib json.dumps already returns str
 	return json.dumps( data )
 def json_parse( string ):
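The new `pretty` flag exercises both serializer branches: `use_orjson` is set by the try/except import at the top of the file, orjson's dumps returns bytes (hence the `.decode`), and the stdlib fallback already returns a str. A usage sketch, assuming the package path `vall_e.utils.io`:

from vall_e.utils.io import json_stringify  # module path assumed

sample = {"speaker": "EN_B00008_S00001", "text": "hello"}  # placeholder data
print(json_stringify(sample, pretty=True))
# with orjson installed: two-space indent via orjson.OPT_INDENT_2
# stdlib fallback:       tab indent via indent='\t'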

View File

@@ -5,6 +5,10 @@ import argparse
 import random
 import tempfile
 import functools
+
+import torch
+import numpy as np
+
 from datetime import datetime
 import torchaudio
@@ -16,6 +20,8 @@ from pathlib import Path
 from .inference import TTS, cfg
 from .train import train
 from .utils import get_devices, setup_logging
+from .utils.io import json_read, json_stringify
+from .emb.qnt import decode_to_wave
 tts = None
@@ -23,6 +29,7 @@ layout = {}
 layout["inference_tts"] = {}
 layout["inference_stt"] = {}
 layout["training"] = {}
+layout["dataset"] = {}
 layout["settings"] = {}
 for k in layout.keys():
@@ -90,6 +97,29 @@ def load_model( yaml, device, dtype, attention ):
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
+def get_speakers():
+	return cfg.dataset.training
+
+#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
+def load_sample( speaker ):
+	metadata_path = cfg.metadata_dir / f'{speaker}.json'
+	metadata = json_read( metadata_path )
+
+	if not metadata:
+		raise gr.Error(f"Metadata not found: {metadata_path}")
+
+	key = random.choice( list(metadata.keys()) )
+	path = cfg.data_dir / speaker / f'{key}.enc' # to-do: get proper file extension
+	data = json_stringify( metadata[key], pretty=True )
+	wav, sr = None, None
+
+	if path.exists():
+		artifact = np.load(path, allow_pickle=True)[()]
+		codes = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=cfg.device)
+		wav, sr = decode_to_wave( codes )
+		wav = wav.squeeze(0).cpu().numpy()
+
+	return data, (sr, wav)
 def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
 	global tts
@@ -429,6 +459,21 @@ with ui:
 				)
 	"""
+	with gr.Tab("Dataset"):
+		with gr.Row():
+			with gr.Column(scale=7):
+				layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
+			with gr.Column(scale=1):
+				layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
+				layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
+				layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
+
+				layout["dataset"]["buttons"]["sample"].click(
+					fn=load_sample,
+					inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
+					outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
+				)
+
 	with gr.Tab("Settings"):
 		with gr.Row():
 			with gr.Column(scale=7):
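The Dataset tab follows the usual Gradio wiring: the speaker dropdown feeds `load_sample`, which returns a string for the metadata Textbox and a `(sample_rate, waveform)` tuple for the Audio component. A self-contained sketch of the same pattern with the vall-e specifics stubbed out (speaker list and sample rate are placeholders):

import numpy as np
import gradio as gr

def load_sample(speaker):
	sr = 24_000  # placeholder sample rate
	wav = np.zeros(sr, dtype=np.float32)  # stand-in: one second of silence
	return f"sampled a clip from {speaker}", (sr, wav)

with gr.Blocks() as demo:
	with gr.Row():
		with gr.Column(scale=7):
			text = gr.Textbox(lines=5, label="Sample Metadata")
		with gr.Column(scale=1):
			speaker = gr.Dropdown(choices=["EN_B00008_S00001"], label="Speakers")
			audio = gr.Audio(label="Output")
			button = gr.Button(value="Sample")
	button.click(fn=load_sample, inputs=[speaker], outputs=[text, audio])

demo.launch()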