From bcec64af0fb188511ddd9412519dc08cb4dd486a Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 17 Feb 2023 19:06:05 +0000
Subject: [PATCH] cleanup, "injected" dvae.pth to download through tortoise's model loader, so I don't need to keep copying it

---
 src/train.py |  15 +++++
 src/utils.py |  53 +++++++++++++-----
 src/webui.py | 155 ++++++++++++++++++++++++---------------------------
 3 files changed, 125 insertions(+), 98 deletions(-)

diff --git a/src/train.py b/src/train.py
index 17d2617..c10fe64 100755
--- a/src/train.py
+++ b/src/train.py
@@ -4,12 +4,27 @@ import argparse
 import os
 import sys
 
+# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
+# its smart-yet-irritating module-model loader breaks when trying to load something specific when not run from a shell
+
 sys.path.insert(0, './dlas/codes/')
+# this is also because DLAS is not written with being used as a package in mind
+# it'll gripe when it wants to import from train.py
 sys.path.insert(0, './dlas/')
 
+# for PIP, replace it with:
+# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
+# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
+
+# don't even really bother trying to get DLAS PIP'd
+# without the kludge, it'll have to be accessible as `codes` and not `dlas`
+
 from codes import train as tr
 from utils import util, options as option
 
+# this is effectively just copy-pasted and cleaned up from the __main__ section of training.py
+# I'll clean it up better
+
 parser = argparse.ArgumentParser()
 parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
 parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
diff --git a/src/utils.py b/src/utils.py
index e7601b0..3545fcf 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -24,18 +24,20 @@ import gradio.utils
 
 from datetime import datetime
 
-from tortoise.api import TextToSpeech
+from tortoise.api import TextToSpeech, MODELS, get_model_path
 from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name
 
 import whisper
 
+MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" # registered so tortoise's model loader downloads and caches dvae.pth
+
 args = None
 tts = None
 webui = None
 voicefixer = None
-dlas = None
+whisper_model = None
 
 def get_args():
     global args
@@ -53,7 +55,7 @@ def setup_args():
         'sample-batch-size': None,
         'embed-output-metadata': True,
         'latents-lean-and-mean': True,
-        'voice-fixer': False, # I'm tired of long initialization of Colab notebooks
+        'voice-fixer': True,
         'voice-fixer-use-cuda': True,
         'force-cpu-for-conditioning-latents': False,
         'device-override': None,
@@ -420,22 +422,46 @@ def generate(
         stats,
     )
 
+def run_training(config_path):
+    global tts
+    del tts # unload the TTS model before spawning the trainer
+    tts = None
+
+    import subprocess
+    subprocess.run(["python", "./src/train.py", "-opt", config_path], env=os.environ.copy(), shell=True, stdout=subprocess.PIPE)
+    """
+    from train import train
+    train(config)
+    """
+
+def setup_voicefixer(restart=False):
+    global voicefixer
+    if restart:
+        del voicefixer
+        voicefixer = None
+
+    try:
+        print("Initializing voice-fixer")
+        from voicefixer import VoiceFixer
+        voicefixer = VoiceFixer()
+        print("Initialized voice-fixer")
+    except Exception as e:
+        print(f"Error occurred while trying to initialize voicefixer: {e}")
+
 def 
setup_tortoise(restart=False): global args global tts - global voicefixer if args.voice_fixer and not restart: - try: - from voicefixer import VoiceFixer - print("Initializating voice-fixer") - voicefixer = VoiceFixer() - print("initialized voice-fixer") - except Exception as e: - print(f"Error occurred while tring to initialize voicefixer: {e}") + setup_voicefixer(restart=restart) + + if restart: + del tts + tts = None print("Initializating TorToiSe...") tts = TextToSpeech(minor_optimizations=not args.low_vram) + get_model_path('dvae.pth') print("TorToiSe initialized, ready for generation.") return tts @@ -461,7 +487,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f: f.write(yaml) -whisper_model = None def prepare_dataset( files, outdir, language=None ): global whisper_model if whisper_model is None: @@ -641,9 +666,7 @@ def check_for_updates(): return False def reload_tts(): - global tts - del tts - tts = setup_tortoise(restart=True) + setup_tortoise(restart=True) def cancel_generate(): tortoise.api.STOP_SIGNAL = True diff --git a/src/webui.py b/src/webui.py index 4f4a103..9fa6738 100755 --- a/src/webui.py +++ b/src/webui.py @@ -123,6 +123,76 @@ def update_presets(value): else: return (gr.update(), gr.update()) +def get_training_configs(): + configs = [] + for i, file in enumerate(sorted(os.listdir(f"./training/"))): + if file[-5:] != ".yaml" or file[0] == ".": + continue + configs.append(f"./training/{file}") + + return configs + +def update_training_configs(): + return gr.update(choices=get_training_configs()) + +def history_view_results( voice ): + results = [] + files = [] + outdir = f"./results/{voice}/" + for i, file in enumerate(sorted(os.listdir(outdir))): + if file[-4:] != ".wav": + continue + + metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False) + if metadata is None: + continue + + values = [] + for k in headers: + v = file + if k != "Name": + v = metadata[headers[k]] + values.append(v) + + + files.append(file) + results.append(values) + + return ( + results, + gr.Dropdown.update(choices=sorted(files)) + ) + +def read_generate_settings_proxy(file, saveAs='.temp'): + j, latents = read_generate_settings(file) + + if latents: + outdir = f'{get_voice_dir()}/{saveAs}/' + os.makedirs(outdir, exist_ok=True) + with open(f'{outdir}/cond_latents.pth', 'wb') as f: + f.write(latents) + + latents = f'{outdir}/cond_latents.pth' + + return ( + j, + gr.update(value=latents, visible=latents is not None), + None if j is None else j['voice'] + ) + +def prepare_dataset_proxy( voice, language ): + return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language ) + +def update_voices(): + return ( + gr.Dropdown.update(choices=get_voice_list()), + gr.Dropdown.update(choices=get_voice_list()), + gr.Dropdown.update(choices=get_voice_list("./results/")), + ) + +def history_copy_settings( voice, file ): + return import_generate_settings( f"./results/{voice}/{file}" ) + def setup_gradio(): global args global ui @@ -279,34 +349,6 @@ def setup_gradio(): with gr.Column(): history_audio = gr.Audio() history_copy_settings_button = gr.Button(value="Copy Settings") - - def history_view_results( voice ): - results = [] - files = [] - outdir = f"./results/{voice}/" - for i, file in enumerate(sorted(os.listdir(outdir))): - if file[-4:] != ".wav": - continue - - metadata, _ = read_generate_settings(f"{outdir}/{file}", 
read_latents=False) - if metadata is None: - continue - - values = [] - for k in headers: - v = file - if k != "Name": - v = metadata[headers[k]] - values.append(v) - - - files.append(file) - results.append(values) - - return ( - results, - gr.Dropdown.update(choices=sorted(files)) - ) history_view_results_button.click( fn=history_view_results, @@ -335,23 +377,6 @@ def setup_gradio(): metadata_out = gr.JSON(label="Audio Metadata") latents_out = gr.File(type="binary", label="Voice Latents") - def read_generate_settings_proxy(file, saveAs='.temp'): - j, latents = read_generate_settings(file) - - if latents: - outdir = f'{get_voice_dir()}/{saveAs}/' - os.makedirs(outdir, exist_ok=True) - with open(f'{outdir}/cond_latents.pth', 'wb') as f: - f.write(latents) - - latents = f'{outdir}/cond_latents.pth' - - return ( - j, - gr.update(value=latents, visible=latents is not None), - None if j is None else j['voice'] - ) - audio_in.upload( fn=read_generate_settings_proxy, inputs=audio_in, @@ -382,9 +407,6 @@ def setup_gradio(): with gr.Column(): prepare_dataset_button = gr.Button(value="Prepare") - def prepare_dataset_proxy( voice, language ): - return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language ) - prepare_dataset_button.click( prepare_dataset_proxy, inputs=dataset_settings, @@ -416,34 +438,12 @@ def setup_gradio(): with gr.Tab("Train"): with gr.Row(): with gr.Column(): - def get_training_configs(): - configs = [] - for i, file in enumerate(sorted(os.listdir(f"./training/"))): - if file[-5:] != ".yaml" or file[0] == ".": - continue - configs.append(f"./training/{file}") - - return configs - def update_training_configs(): - return gr.update(choices=get_training_configs()) - training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs()) refresh_configs = gr.Button(value="Refresh Configurations") - train = gr.Button(value="Train") - - def run_training_proxy( config ): - global tts - del tts - - import subprocess - subprocess.run(["python", "./src/train.py", "-opt", config], env=os.environ.copy(), shell=True) - """ - from train import train - train(config) - """ + train = gr.Button(value="Train") refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs) - train.click(run_training_proxy, + train.click(run_training, inputs=training_configs, outputs=None ) @@ -506,17 +506,6 @@ def setup_gradio(): experimental_checkboxes, ] - # YUCK - def update_voices(): - return ( - gr.Dropdown.update(choices=get_voice_list()), - gr.Dropdown.update(choices=get_voice_list()), - gr.Dropdown.update(choices=get_voice_list("./results/")), - ) - - def history_copy_settings( voice, file ): - return import_generate_settings( f"./results/{voice}/{file}" ) - refresh_voices.click(update_voices, inputs=None, outputs=[