From d5c1433268ac73b432d3d4908eb96867286748fe Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 18 Feb 2023 02:07:22 +0000 Subject: [PATCH] a bit of UI cleanup, import multiple audio files at once, actually shows progress when importing voices, hides audio metadata / latents if no generated settings are detected, preparing datasets shows its progress, saving a training YAML shows a message when done, training now works within the web UI, training output shows to web UI, provided notebook is cleaned up and uses a venv, etc. --- {training => models}/.template.yaml | 0 notebook.ipynb | 56 ++++-- src/train.py | 59 +++--- src/utils.py | 169 ++++++++++------ src/webui.py | 291 +++++++++++++--------------- train.bat | 4 + train.sh | 3 + training/.gitkeep | 0 8 files changed, 323 insertions(+), 259 deletions(-) rename {training => models}/.template.yaml (100%) create mode 100755 train.bat create mode 100755 train.sh create mode 100755 training/.gitkeep diff --git a/training/.template.yaml b/models/.template.yaml similarity index 100% rename from training/.template.yaml rename to models/.template.yaml diff --git a/notebook.ipynb b/notebook.ipynb index 9257f23..b5de8af 100755 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -3,10 +3,7 @@ "nbformat_minor":0, "metadata":{ "colab":{ - "private_outputs":true, - "provenance":[ - - ] + "private_outputs":true }, "kernelspec":{ "name":"python3", @@ -40,41 +37,62 @@ "source":[ "!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n", "%cd ai-voice-cloning\n", + "!apt install python3.8-venv\n", + "!python -m venv venv\n", + "!source ./venv/bin/activate\n", + "!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n", "!python -m pip install --upgrade pip\n", "!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n", - "!python -m pip install -r ./requirements.txt\n", - "!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n", - "!python -m pip install -r ./dlas/requirements.txt" + "!python -m pip install -r ./dlas/requirements.txt\n", + "!python -m pip install -r ./requirements.txt" ] }, { "cell_type":"markdown", "source":[ - "# Restart Runtime Before Proceeding" + "# Update Repos" ], "metadata":{ - "id":"TXFyLVLA48S5" + "id":"IzrGt5IcHlAD" } }, { "cell_type":"code", "source":[ - "# colab requires the runtime to restart before use\n", - "exit()" + "%cd /content/ai-voice-cloning/dlas\n", + "!git reset --hard HEAD\n", + "!git pull\n", + "%cd ..\n", + "!git reset --hard HEAD\n", + "!git pull\n", + "!python -m pip install ffmpeg ffmpeg-python" ], "metadata":{ - "id":"FVUOtSASCSJ8" + "id":"3DktoOXSHmtw" }, "execution_count":null, "outputs":[ ] }, + { + "cell_type":"markdown", + "source":[ + "# Mount Drive" + ], + "metadata":{ + "id":"2Y4t9zDIZMTg" + } + }, { "cell_type":"code", "source":[ "from google.colab import drive\n", - "drive.mount('/content/drive')" + "drive.mount('/content/drive')\n", + "\n", + "%cd /content/ai-voice-cloning\n", + "!rm -r ./training\n", + "!ln -s /content/drive/MyDrive/training/" ], "metadata":{ "id":"SGt9gyvubveT" @@ -97,6 +115,8 @@ "cell_type":"code", "source":[ "%cd /content/ai-voice-cloning\n", + "!python -m venv venv\n", + "!source ./venv/bin/activate\n", "\n", "import os\n", "import sys\n", @@ -117,7 +137,7 @@ "\n", "webui = setup_gradio()\n", "tts = setup_tortoise()\n", - "webui.launch(share=True, prevent_thread_lock=True, debug=True, height=1000)\n", + "webui.launch(share=True, prevent_thread_lock=True, height=1000)\n", "webui.block_thread()" ], "metadata":{ @@ -140,8 +160,9 @@ { 
"cell_type":"code", "source":[ + "# This is in case you can't get training through the web UI\n", "%cd /content/ai-voice-cloning\n", - "!python ./src/train.py -opt ./training/finetune.yaml" + "!python ./dlas/codes/train.py -opt ./training/finetune.yaml" ], "metadata":{ "id":"-KayB8klA5tY" @@ -167,8 +188,9 @@ "!apt install -y p7zip-full\n", "from datetime import datetime\n", "timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n", - "!mkdir -p \"../{timestamp}\"\n", - "!mv ./results/* \"../{timestamp}/.\"\n", + "!mkdir -p \"../{timestamp}/results\"\n", + "!mv ./results/* \"../{timestamp}/results/.\"\n", + "!mv ./training/* \"../{timestamp}/training/.\"\n", "!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n", "!ls ~/\n", "!echo \"Finished zipping, archive is available at {timestamp}.7z\"" diff --git a/src/train.py b/src/train.py index c10fe64..80b243a 100755 --- a/src/train.py +++ b/src/train.py @@ -25,32 +25,37 @@ from utils import util, options as option # this is effectively just copy pasted and cleaned up from the __main__ section of training.py # I'll clean it up better -parser = argparse.ArgumentParser() -parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml') -parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') -args = parser.parse_args() -opt = option.parse(args.opt, is_train=True) -if args.launcher != 'none': - # export CUDA_VISIBLE_DEVICES for running in distributed mode. - if 'gpu_ids' in opt.keys(): - gpu_list = ','.join(str(x) for x in opt['gpu_ids']) - os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list - print('export CUDA_VISIBLE_DEVICES=' + gpu_list) -trainer = tr.Trainer() +def train(yaml, launcher='none'): + opt = option.parse(yaml, is_train=True) + if launcher != 'none': + # export CUDA_VISIBLE_DEVICES for running in distributed mode. 
+ if 'gpu_ids' in opt.keys(): + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + trainer = tr.Trainer() -#### distributed training settings -if args.launcher == 'none': # disabled distributed training - opt['dist'] = False - trainer.rank = -1 - if len(opt['gpu_ids']) == 1: - torch.cuda.set_device(opt['gpu_ids'][0]) - print('Disabled distributed training.') -else: - opt['dist'] = True - init_dist('nccl') - trainer.world_size = torch.distributed.get_world_size() - trainer.rank = torch.distributed.get_rank() - torch.cuda.set_device(torch.distributed.get_rank()) + #### distributed training settings + if launcher == 'none': # disabled distributed training + opt['dist'] = False + trainer.rank = -1 + if len(opt['gpu_ids']) == 1: + torch.cuda.set_device(opt['gpu_ids'][0]) + print('Disabled distributed training.') + else: + opt['dist'] = True + init_dist('nccl') + trainer.world_size = torch.distributed.get_world_size() + trainer.rank = torch.distributed.get_rank() + torch.cuda.set_device(torch.distributed.get_rank()) -trainer.init(args.opt, opt, args.launcher) -trainer.do_training() \ No newline at end of file + trainer.init(yaml, opt, launcher) + trainer.do_training() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + args = parser.parse_args() + + train(args.opt, args.launcher) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 5015a10..b391075 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1,5 +1,4 @@ import os - if 'XDG_CACHE_HOME' not in os.environ: os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/')) @@ -15,7 +14,9 @@ import json import base64 import re import urllib.request +import signal +import tqdm import torch import torchaudio import music_tag @@ -90,6 +91,8 @@ def setup_args(): parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)") parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output") + + parser.add_argument("--os", default="unix", help="Specifies which OS, easily") args = parser.parse_args() args.embed_output_metadata = not args.no_embed_output_metadata @@ -427,20 +430,37 @@ def generate( import subprocess +training_process = None def run_training(config_path): print("Unloading TTS to save VRAM.") global tts del tts tts = None - cmd = ["python", "./src/train.py", "-opt", config_path] + global training_process + torch.multiprocessing.freeze_support() + cmd = [f'train.{"bat" if args.os == "windows" else "sh"}', config_path] print("Spawning process: ", " ".join(cmd)) - subprocess.run(cmd, env=os.environ.copy(), shell=True) - """ - from train import train - train(config) - """ + training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) + buffer=[] + for line in iter(training_process.stdout.readline, ""): + buffer.append(line) + yield "".join(buffer) + + training_process.stdout.close() + return_code = 
training_process.wait() + training_process = None + if return_code: + raise subprocess.CalledProcessError(return_code, cmd) + + +def stop_training(): + if training_process is None: + return "No training in progress" + training_process.kill() + training_process = None + return "Training cancelled" def setup_voicefixer(restart=False): global voicefixer @@ -485,19 +505,23 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None "validation_name": validation_name if validation_name else "finetune", "validation_path": validation_path if validation_path else "./training/finetune/train.txt", } + outfile = f'./training/{settings["name"]}.yaml' - with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f: + with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f: yaml = f.read() for k in settings: yaml = yaml.replace(f"${{{k}}}", str(settings[k])) - - with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f: + + with open(outfile, 'w', encoding="utf-8") as f: f.write(yaml) -def prepare_dataset( files, outdir, language=None ): + return f"Training settings saved to: {outfile}" + +def prepare_dataset( files, outdir, language=None, progress=None ): global whisper_model if whisper_model is None: + notify_progress(f"Loading Whisper model: {args.whisper_model}", progress) whisper_model = whisper.load_model(args.whisper_model) os.makedirs(outdir, exist_ok=True) @@ -506,7 +530,7 @@ def prepare_dataset( files, outdir, language=None ): results = {} transcription = [] - for file in files: + for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): print(f"Transcribing file: {file}") result = whisper_model.transcribe(file, language=language if language else "English") @@ -517,7 +541,7 @@ def prepare_dataset( files, outdir, language=None ): waveform, sampling_rate = torchaudio.load(file) num_channels, num_frames = waveform.shape - for segment in result['segments']: + for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress): start = int(segment['start'] * sampling_rate) end = int(segment['end'] * sampling_rate) @@ -535,66 +559,74 @@ def prepare_dataset( files, outdir, language=None ): with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f: f.write("\n".join(transcription)) + return f"Processed dataset to: {outdir}" + def reset_generation_settings(): with open(f'./config/generate.json', 'w', encoding="utf-8") as f: f.write(json.dumps({}, indent='\t') ) return import_generate_settings() -def import_voice(file, saveAs = None): +def import_voices(files, saveAs=None, progress=None): global args - j, latents = read_generate_settings(file, read_latents=True) - - if j is not None and saveAs is None: - saveAs = j['voice'] - if saveAs is None or saveAs == "": - raise Exception("Specify a voice name") + if not isinstance(files, list): + files = [files] - outdir = f'{get_voice_dir()}/{saveAs}/' - os.makedirs(outdir, exist_ok=True) - if latents: - with open(f'{outdir}/cond_latents.pth', 'wb') as f: - f.write(latents) - latents = f'{outdir}/cond_latents.pth' - print(f"Imported latents to {latents}") - else: - filename = file.name - if filename[-4:] != ".wav": - raise Exception("Please convert to a WAV first") + for file in enumerate_progress(files, desc="Importing voice files", progress=progress): + j, latents = read_generate_settings(file, read_latents=True) + + if j is not None and saveAs is None: + saveAs = j['voice'] + if saveAs is None or saveAs 
== "": + raise Exception("Specify a voice name") - path = f"{outdir}/{os.path.basename(filename)}" - waveform, sampling_rate = torchaudio.load(filename) + outdir = f'{get_voice_dir()}/{saveAs}/' + os.makedirs(outdir, exist_ok=True) - if args.voice_fixer and voicefixer is not None: - # resample to best bandwidth since voicefixer will do it anyways through librosa - if sampling_rate != 44100: - print(f"Resampling imported voice sample: {path}") - resampler = torchaudio.transforms.Resample( - sampling_rate, - 44100, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) - waveform = resampler(waveform) - sampling_rate = 44100 - - torchaudio.save(path, waveform, sampling_rate) - - print(f"Running 'voicefixer' on voice sample: {path}") - voicefixer.restore( - input = path, - output = path, - cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, - #mode=mode, - ) + if latents: + print(f"Importing latents to {latents}") + with open(f'{outdir}/cond_latents.pth', 'wb') as f: + f.write(latents) + latents = f'{outdir}/cond_latents.pth' + print(f"Imported latents to {latents}") else: - torchaudio.save(path, waveform, sampling_rate) + filename = file.name + if filename[-4:] != ".wav": + raise Exception("Please convert to a WAV first") + path = f"{outdir}/{os.path.basename(filename)}" + print(f"Importing voice to {path}") - print(f"Imported voice to {path}") + waveform, sampling_rate = torchaudio.load(filename) + if args.voice_fixer and voicefixer is not None: + # resample to best bandwidth since voicefixer will do it anyways through librosa + if sampling_rate != 44100: + print(f"Resampling imported voice sample: {path}") + resampler = torchaudio.transforms.Resample( + sampling_rate, + 44100, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ) + waveform = resampler(waveform) + sampling_rate = 44100 + + torchaudio.save(path, waveform, sampling_rate) + + print(f"Running 'voicefixer' on voice sample: {path}") + voicefixer.restore( + input = path, + output = path, + cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, + #mode=mode, + ) + else: + torchaudio.save(path, waveform, sampling_rate) + + print(f"Imported voice to {path}") def import_generate_settings(file="./config/generate.json"): settings, _ = read_generate_settings(file, read_latents=False) @@ -759,4 +791,21 @@ def read_generate_settings(file, read_latents=True, read_json=True): return ( j, latents, - ) \ No newline at end of file + ) + +def enumerate_progress(iterable, desc=None, progress=None, verbose=None): + if verbose and desc is not None: + print(desc) + + if progress is None: + return tqdm(iterable, disable=not verbose) + return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True) + +def notify_progress(message, progress=None, verbose=True): + if verbose: + print(message) + + if progress is None: + return + + progress(0, desc=message) \ No newline at end of file diff --git a/src/webui.py b/src/webui.py index 2101f85..c56e38c 100755 --- a/src/webui.py +++ b/src/webui.py @@ -135,6 +135,21 @@ def get_training_configs(): def update_training_configs(): return gr.update(choices=get_training_configs()) +history_headers = { + "Name": "", + "Samples": "num_autoregressive_samples", + "Iterations": "diffusion_iterations", + "Temp.": "temperature", + "Sampler": "diffusion_sampler", + "CVVP": "cvvp_weight", + "Top P": "top_p", + "Diff. 
Temp.": "diffusion_temperature", + "Len Pen": "length_penalty", + "Rep Pen": "repetition_penalty", + "Cond-Free K": "cond_free_k", + "Time": "time", +} + def history_view_results( voice ): results = [] files = [] @@ -148,7 +163,7 @@ def history_view_results( voice ): continue values = [] - for k in headers: + for k in history_headers: v = file if k != "Name": v = metadata[headers[k]] @@ -163,6 +178,10 @@ def history_view_results( voice ): gr.Dropdown.update(choices=sorted(files)) ) +def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)): + import_voices(files, name, progress) + return gr.update() + def read_generate_settings_proxy(file, saveAs='.temp'): j, latents = read_generate_settings(file) @@ -175,13 +194,14 @@ def read_generate_settings_proxy(file, saveAs='.temp'): latents = f'{outdir}/cond_latents.pth' return ( - j, + gr.update(value=j, visible=j is not None), + gr.update(visible=j is not None), gr.update(value=latents, visible=latents is not None), None if j is None else j['voice'] ) -def prepare_dataset_proxy( voice, language ): - return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language ) +def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): + return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) def update_voices(): return ( @@ -222,52 +242,18 @@ def setup_gradio(): with gr.Column(): delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") - emotion = gr.Radio( - ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], - value="Custom", - label="Emotion", - type="value", - interactive=True - ) + emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True ) prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") - voice = gr.Dropdown( - get_voice_list(), - label="Voice", - type="value", - ) - mic_audio = gr.Audio( - label="Microphone Source", - source="microphone", - type="filepath", - ) + voice = gr.Dropdown(get_voice_list(), label="Voice", type="value") + mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) refresh_voices = gr.Button(value="Refresh Voice List") voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") - recompute_voice_latents.click(compute_latents, - inputs=[ - voice, - voice_latents_chunks, - ], - outputs=voice, - ) - - prompt.change(fn=lambda value: gr.update(value="Custom"), - inputs=prompt, - outputs=emotion - ) - mic_audio.change(fn=lambda value: gr.update(value="microphone"), - inputs=mic_audio, - outputs=voice - ) with gr.Column(): candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") seed = gr.Number(value=0, precision=0, label="Seed") - preset = gr.Radio( - ["Ultra Fast", "Fast", "Standard", "High Quality"], - label="Preset", - type="value", - ) + preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" ) num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples") diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations") @@ -275,19 +261,7 @@ def setup_gradio(): breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size") 
diffusion_sampler = gr.Radio( ["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"], - value="P", - label="Diffusion Samplers", - type="value", - ) - - preset.change(fn=update_presets, - inputs=preset, - outputs=[ - num_autoregressive_samples, - diffusion_iterations, - ], - ) - + value="P", label="Diffusion Samplers", type="value" ) show_experimental_settings = gr.Checkbox(label="Show Experimental Settings") reset_generation_settings_button = gr.Button(value="Reset to Default") with gr.Column(visible=False) as col: @@ -300,12 +274,6 @@ def setup_gradio(): length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") - - show_experimental_settings.change( - fn=lambda x: gr.update(visible=x), - inputs=show_experimental_settings, - outputs=experimental_column - ) with gr.Column(): submit = gr.Button(value="Generate") stop = gr.Button(value="Stop") @@ -315,33 +283,13 @@ def setup_gradio(): output_audio = gr.Audio(label="Output") candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False) output_pick = gr.Button(value="Select Candidate", visible=False) - with gr.Tab("History"): with gr.Row(): with gr.Column(): - headers = { - "Name": "", - "Samples": "num_autoregressive_samples", - "Iterations": "diffusion_iterations", - "Temp.": "temperature", - "Sampler": "diffusion_sampler", - "CVVP": "cvvp_weight", - "Top P": "top_p", - "Diff. Temp.": "diffusion_temperature", - "Len Pen": "length_penalty", - "Rep Pen": "repetition_penalty", - "Cond-Free K": "cond_free_k", - "Time": "time", - } - history_info = gr.Dataframe(label="Results", headers=list(headers.keys())) + history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) with gr.Row(): with gr.Column(): - history_voices = gr.Dropdown( - get_voice_list("./results/"), - label="Voice", - type="value", - ) - + history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value") history_view_results_button = gr.Button(value="View Files") with gr.Column(): history_results_list = gr.Dropdown(label="Results",type="value", interactive=True) @@ -349,51 +297,16 @@ def setup_gradio(): with gr.Column(): history_audio = gr.Audio() history_copy_settings_button = gr.Button(value="Copy Settings") - - history_view_results_button.click( - fn=history_view_results, - inputs=history_voices, - outputs=[ - history_info, - history_results_list, - ] - ) - history_view_result_button.click( - fn=lambda voice, file: f"./results/{voice}/{file}", - inputs=[ - history_voices, - history_results_list, - ], - outputs=history_audio - ) with gr.Tab("Utilities"): with gr.Row(): with gr.Column(): - audio_in = gr.File(type="file", label="Audio Input", file_types=["audio"]) - copy_button = gr.Button(value="Copy Settings") + audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"]) import_voice_name = gr.Textbox(label="Voice Name") import_voice_button = gr.Button(value="Import Voice") with gr.Column(): - metadata_out = gr.JSON(label="Audio Metadata") - latents_out = gr.File(type="binary", label="Voice Latents") - - audio_in.upload( - fn=read_generate_settings_proxy, - inputs=audio_in, - outputs=[ - metadata_out, - latents_out, - import_voice_name - ] - ) - - import_voice_button.click( - fn=import_voice, - inputs=[ - audio_in, - import_voice_name, - ] - ) + metadata_out = gr.JSON(label="Audio Metadata", visible=False) 
+ copy_button = gr.Button(value="Copy Settings", visible=False) + latents_out = gr.File(type="binary", label="Voice Latents", visible=False) with gr.Tab("Training"): with gr.Tab("Prepare Dataset"): with gr.Row(): @@ -402,16 +315,9 @@ def setup_gradio(): gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ), gr.Textbox(label="Language", placeholder="English") ] - dataset_voices = dataset_settings[0] - - with gr.Column(): prepare_dataset_button = gr.Button(value="Prepare") - - prepare_dataset_button.click( - prepare_dataset_proxy, - inputs=dataset_settings, - outputs=None - ) + with gr.Column(): + prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) with gr.Tab("Generate Configuration"): with gr.Row(): with gr.Column(): @@ -421,8 +327,6 @@ def setup_gradio(): gr.Number(label="Print Frequency", value=50), gr.Number(label="Save Frequency", value=50), ] - save_yaml_button = gr.Button(value="Save Training Configuration") - with gr.Column(): training_settings = training_settings + [ gr.Textbox(label="Training Name", placeholder="finetune"), gr.Textbox(label="Dataset Name", placeholder="finetune"), @@ -430,24 +334,18 @@ def setup_gradio(): gr.Textbox(label="Validation Name", placeholder="finetune"), gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"), ] - - save_yaml_button.click(save_training_settings, - inputs=training_settings, - outputs=None - ) - with gr.Tab("Train"): + with gr.Column(): + save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) + save_yaml_button = gr.Button(value="Save Training Configuration") + with gr.Tab("Run Training"): with gr.Row(): with gr.Column(): training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs()) refresh_configs = gr.Button(value="Refresh Configurations") - train = gr.Button(value="Train") - - refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs) - train.click(run_training, - inputs=training_configs, - outputs=None - ) - + start_training_button = gr.Button(value="Train") + stop_training_button = gr.Button(value="Stop") + with gr.Column(): + training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) with gr.Tab("Settings"): with gr.Row(): exec_inputs = [] @@ -465,23 +363,22 @@ def setup_gradio(): gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load), gr.Textbox(label="Device Override", value=args.device_override), - gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]), ] - gr.Button(value="Check for Updates").click(check_for_updates) - gr.Button(value="Reload TTS").click(reload_tts) with gr.Column(): exec_inputs = exec_inputs + [ gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate), gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), + gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]), ] + gr.Button(value="Check for Updates").click(check_for_updates) + gr.Button(value="Reload TTS").click(reload_tts) for i in 
exec_inputs: - i.change( - fn=export_exec_settings, - inputs=exec_inputs - ) + i.change( fn=export_exec_settings, inputs=exec_inputs ) + + # console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) input_settings = [ text, @@ -507,11 +404,76 @@ def setup_gradio(): experimental_checkboxes, ] + history_view_results_button.click( + fn=history_view_results, + inputs=history_voices, + outputs=[ + history_info, + history_results_list, + ] + ) + history_view_result_button.click( + fn=lambda voice, file: f"./results/{voice}/{file}", + inputs=[ + history_voices, + history_results_list, + ], + outputs=history_audio + ) + audio_in.upload( + fn=read_generate_settings_proxy, + inputs=audio_in, + outputs=[ + metadata_out, + copy_button, + latents_out, + import_voice_name + ] + ) + + import_voice_button.click( + fn=import_voices_proxy, + inputs=[ + audio_in, + import_voice_name, + ], + outputs=import_voice_name #console_output + ) + show_experimental_settings.change( + fn=lambda x: gr.update(visible=x), + inputs=show_experimental_settings, + outputs=experimental_column + ) + preset.change(fn=update_presets, + inputs=preset, + outputs=[ + num_autoregressive_samples, + diffusion_iterations, + ], + ) + + recompute_voice_latents.click(compute_latents, + inputs=[ + voice, + voice_latents_chunks, + ], + outputs=voice, + ) + + prompt.change(fn=lambda value: gr.update(value="Custom"), + inputs=prompt, + outputs=emotion + ) + mic_audio.change(fn=lambda value: gr.update(value="microphone"), + inputs=mic_audio, + outputs=voice + ) + refresh_voices.click(update_voices, inputs=None, outputs=[ voice, - dataset_voices, + dataset_settings[0], history_voices ] ) @@ -552,6 +514,25 @@ def setup_gradio(): outputs=input_settings ) + refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs) + start_training_button.click(run_training, + inputs=training_configs, + outputs=training_output #console_output + ) + stop_training_button.click(stop_training, + inputs=None, + outputs=training_output #console_output + ) + prepare_dataset_button.click( + prepare_dataset_proxy, + inputs=dataset_settings, + outputs=prepare_dataset_output #console_output + ) + save_yaml_button.click(save_training_settings, + inputs=training_settings, + outputs=save_yaml_output #console_output + ) + if os.path.isfile('./config/generate.json'): ui.load(import_generate_settings, inputs=None, outputs=input_settings) diff --git a/train.bat b/train.bat new file mode 100755 index 0000000..55d682a --- /dev/null +++ b/train.bat @@ -0,0 +1,4 @@ +call .\venv\Scripts\activate.bat +python ./src/train.py -opt "%1" +deactivate +pause \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..2840762 --- /dev/null +++ b/train.sh @@ -0,0 +1,3 @@ +source ./venv/bin/activate +python3 ./src/train.py -opt "$1" +deactivate diff --git a/training/.gitkeep b/training/.gitkeep new file mode 100755 index 0000000..e69de29
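
For reference, a minimal usage sketch of the new training wrappers added by this patch (assuming the repository root as the working directory and a YAML previously generated through the "Generate Configuration" tab; the config path below is illustrative):

    # POSIX: train.sh activates ./venv, runs src/train.py against the given config, then deactivates
    ./train.sh ./training/finetune.yaml
    # Windows equivalent
    train.bat .\training\finetune.yaml

This mirrors the command that run_training() in src/utils.py now spawns when training is started from the web UI's "Run Training" tab.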