cleanup, "injected" dvae.pth to download through tortoise's model loader, so I don't need to keep copying it
This commit is contained in:
parent
13c9920b7f
commit
bcec64af0f
15
src/train.py
15
src/train.py
|
@ -4,12 +4,27 @@ import argparse
|
|||
import os
|
||||
import sys
|
||||
|
||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||
|
||||
sys.path.insert(0, './dlas/codes/')
|
||||
# this is also because DLAS is not written as a package in mind
|
||||
# it'll gripe when it wants to import from train.py
|
||||
sys.path.insert(0, './dlas/')
|
||||
|
||||
# for PIP, replace it with:
|
||||
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
|
||||
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
|
||||
|
||||
# don't even really bother trying to get DLAS PIP'd
|
||||
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
||||
|
||||
from codes import train as tr
|
||||
from utils import util, options as option
|
||||
|
||||
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||
# I'll clean it up better
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
|
|
53
src/utils.py
53
src/utils.py
|
@ -24,18 +24,20 @@ import gradio.utils
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
from tortoise.api import TextToSpeech
|
||||
from tortoise.api import TextToSpeech, MODELS, get_model_path
|
||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||
from tortoise.utils.text import split_and_recombine_text
|
||||
from tortoise.utils.device import get_device_name, set_device_name
|
||||
|
||||
import whisper
|
||||
|
||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||
|
||||
args = None
|
||||
tts = None
|
||||
webui = None
|
||||
voicefixer = None
|
||||
dlas = None
|
||||
whisper_model = None
|
||||
|
||||
def get_args():
|
||||
global args
|
||||
|
@ -53,7 +55,7 @@ def setup_args():
|
|||
'sample-batch-size': None,
|
||||
'embed-output-metadata': True,
|
||||
'latents-lean-and-mean': True,
|
||||
'voice-fixer': False, # I'm tired of long initialization of Colab notebooks
|
||||
'voice-fixer': True,
|
||||
'voice-fixer-use-cuda': True,
|
||||
'force-cpu-for-conditioning-latents': False,
|
||||
'device-override': None,
|
||||
|
@ -420,22 +422,46 @@ def generate(
|
|||
stats,
|
||||
)
|
||||
|
||||
def run_training(config_path):
|
||||
global tts
|
||||
del tts
|
||||
tts = None
|
||||
|
||||
import subprocess
|
||||
subprocess.run(["python", "./src/train.py", "-opt", config_path], env=os.environ.copy(), shell=True, stdout=subprocess.PIPE)
|
||||
"""
|
||||
from train import train
|
||||
train(config)
|
||||
"""
|
||||
|
||||
def setup_voicefixer(restart=False):
|
||||
global voicefixer
|
||||
if restart:
|
||||
del voicefixer
|
||||
voicefixer = None
|
||||
|
||||
try:
|
||||
print("Initializating voice-fixer")
|
||||
from voicefixer import VoiceFixer
|
||||
voicefixer = VoiceFixer()
|
||||
print("initialized voice-fixer")
|
||||
except Exception as e:
|
||||
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
||||
|
||||
def setup_tortoise(restart=False):
|
||||
global args
|
||||
global tts
|
||||
global voicefixer
|
||||
|
||||
if args.voice_fixer and not restart:
|
||||
try:
|
||||
from voicefixer import VoiceFixer
|
||||
print("Initializating voice-fixer")
|
||||
voicefixer = VoiceFixer()
|
||||
print("initialized voice-fixer")
|
||||
except Exception as e:
|
||||
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
||||
setup_voicefixer(restart=restart)
|
||||
|
||||
if restart:
|
||||
del tts
|
||||
tts = None
|
||||
|
||||
print("Initializating TorToiSe...")
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||
get_model_path('dvae.pth')
|
||||
print("TorToiSe initialized, ready for generation.")
|
||||
return tts
|
||||
|
||||
|
@ -461,7 +487,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
|||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||
f.write(yaml)
|
||||
|
||||
whisper_model = None
|
||||
def prepare_dataset( files, outdir, language=None ):
|
||||
global whisper_model
|
||||
if whisper_model is None:
|
||||
|
@ -641,9 +666,7 @@ def check_for_updates():
|
|||
return False
|
||||
|
||||
def reload_tts():
|
||||
global tts
|
||||
del tts
|
||||
tts = setup_tortoise(restart=True)
|
||||
setup_tortoise(restart=True)
|
||||
|
||||
def cancel_generate():
|
||||
tortoise.api.STOP_SIGNAL = True
|
||||
|
|
153
src/webui.py
153
src/webui.py
|
@ -123,6 +123,76 @@ def update_presets(value):
|
|||
else:
|
||||
return (gr.update(), gr.update())
|
||||
|
||||
def get_training_configs():
|
||||
configs = []
|
||||
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
|
||||
if file[-5:] != ".yaml" or file[0] == ".":
|
||||
continue
|
||||
configs.append(f"./training/{file}")
|
||||
|
||||
return configs
|
||||
|
||||
def update_training_configs():
|
||||
return gr.update(choices=get_training_configs())
|
||||
|
||||
def history_view_results( voice ):
|
||||
results = []
|
||||
files = []
|
||||
outdir = f"./results/{voice}/"
|
||||
for i, file in enumerate(sorted(os.listdir(outdir))):
|
||||
if file[-4:] != ".wav":
|
||||
continue
|
||||
|
||||
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
|
||||
if metadata is None:
|
||||
continue
|
||||
|
||||
values = []
|
||||
for k in headers:
|
||||
v = file
|
||||
if k != "Name":
|
||||
v = metadata[headers[k]]
|
||||
values.append(v)
|
||||
|
||||
|
||||
files.append(file)
|
||||
results.append(values)
|
||||
|
||||
return (
|
||||
results,
|
||||
gr.Dropdown.update(choices=sorted(files))
|
||||
)
|
||||
|
||||
def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||
j, latents = read_generate_settings(file)
|
||||
|
||||
if latents:
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
|
||||
return (
|
||||
j,
|
||||
gr.update(value=latents, visible=latents is not None),
|
||||
None if j is None else j['voice']
|
||||
)
|
||||
|
||||
def prepare_dataset_proxy( voice, language ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||
)
|
||||
|
||||
def history_copy_settings( voice, file ):
|
||||
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||
|
||||
def setup_gradio():
|
||||
global args
|
||||
global ui
|
||||
|
@ -280,34 +350,6 @@ def setup_gradio():
|
|||
history_audio = gr.Audio()
|
||||
history_copy_settings_button = gr.Button(value="Copy Settings")
|
||||
|
||||
def history_view_results( voice ):
|
||||
results = []
|
||||
files = []
|
||||
outdir = f"./results/{voice}/"
|
||||
for i, file in enumerate(sorted(os.listdir(outdir))):
|
||||
if file[-4:] != ".wav":
|
||||
continue
|
||||
|
||||
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
|
||||
if metadata is None:
|
||||
continue
|
||||
|
||||
values = []
|
||||
for k in headers:
|
||||
v = file
|
||||
if k != "Name":
|
||||
v = metadata[headers[k]]
|
||||
values.append(v)
|
||||
|
||||
|
||||
files.append(file)
|
||||
results.append(values)
|
||||
|
||||
return (
|
||||
results,
|
||||
gr.Dropdown.update(choices=sorted(files))
|
||||
)
|
||||
|
||||
history_view_results_button.click(
|
||||
fn=history_view_results,
|
||||
inputs=history_voices,
|
||||
|
@ -335,23 +377,6 @@ def setup_gradio():
|
|||
metadata_out = gr.JSON(label="Audio Metadata")
|
||||
latents_out = gr.File(type="binary", label="Voice Latents")
|
||||
|
||||
def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||
j, latents = read_generate_settings(file)
|
||||
|
||||
if latents:
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
|
||||
return (
|
||||
j,
|
||||
gr.update(value=latents, visible=latents is not None),
|
||||
None if j is None else j['voice']
|
||||
)
|
||||
|
||||
audio_in.upload(
|
||||
fn=read_generate_settings_proxy,
|
||||
inputs=audio_in,
|
||||
|
@ -382,9 +407,6 @@ def setup_gradio():
|
|||
with gr.Column():
|
||||
prepare_dataset_button = gr.Button(value="Prepare")
|
||||
|
||||
def prepare_dataset_proxy( voice, language ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
||||
|
||||
prepare_dataset_button.click(
|
||||
prepare_dataset_proxy,
|
||||
inputs=dataset_settings,
|
||||
|
@ -416,34 +438,12 @@ def setup_gradio():
|
|||
with gr.Tab("Train"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
def get_training_configs():
|
||||
configs = []
|
||||
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
|
||||
if file[-5:] != ".yaml" or file[0] == ".":
|
||||
continue
|
||||
configs.append(f"./training/{file}")
|
||||
|
||||
return configs
|
||||
def update_training_configs():
|
||||
return gr.update(choices=get_training_configs())
|
||||
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||
train = gr.Button(value="Train")
|
||||
|
||||
def run_training_proxy( config ):
|
||||
global tts
|
||||
del tts
|
||||
|
||||
import subprocess
|
||||
subprocess.run(["python", "./src/train.py", "-opt", config], env=os.environ.copy(), shell=True)
|
||||
"""
|
||||
from train import train
|
||||
train(config)
|
||||
"""
|
||||
|
||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||
train.click(run_training_proxy,
|
||||
train.click(run_training,
|
||||
inputs=training_configs,
|
||||
outputs=None
|
||||
)
|
||||
|
@ -506,17 +506,6 @@ def setup_gradio():
|
|||
experimental_checkboxes,
|
||||
]
|
||||
|
||||
# YUCK
|
||||
def update_voices():
|
||||
return (
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||
)
|
||||
|
||||
def history_copy_settings( voice, file ):
|
||||
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||
|
||||
refresh_voices.click(update_voices,
|
||||
inputs=None,
|
||||
outputs=[
|
||||
|
|
Loading…
Reference in New Issue
Block a user