cleanup, "injected" dvae.pth to download through tortoise's model loader, so I don't need to keep copying it

This commit is contained in:
mrq 2023-02-17 19:06:05 +00:00
parent 13c9920b7f
commit bcec64af0f
3 changed files with 125 additions and 98 deletions

View File

@ -4,12 +4,27 @@ import argparse
import os
import sys
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
sys.path.insert(0, './dlas/codes/')
# this is also because DLAS is not written as a package in mind
# it'll gripe when it wants to import from train.py
sys.path.insert(0, './dlas/')
# for PIP, replace it with:
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
# don't even really bother trying to get DLAS PIP'd
# without kludge, it'll have to be accessible as `codes` and not `dlas`
from codes import train as tr
from utils import util, options as option
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
# I'll clean it up better
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')

View File

@ -24,18 +24,20 @@ import gradio.utils
from datetime import datetime
from tortoise.api import TextToSpeech
from tortoise.api import TextToSpeech, MODELS, get_model_path
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name
import whisper
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
args = None
tts = None
webui = None
voicefixer = None
dlas = None
whisper_model = None
def get_args():
global args
@ -53,7 +55,7 @@ def setup_args():
'sample-batch-size': None,
'embed-output-metadata': True,
'latents-lean-and-mean': True,
'voice-fixer': False, # I'm tired of long initialization of Colab notebooks
'voice-fixer': True,
'voice-fixer-use-cuda': True,
'force-cpu-for-conditioning-latents': False,
'device-override': None,
@ -420,22 +422,46 @@ def generate(
stats,
)
def setup_tortoise(restart=False):
global args
def run_training(config_path):
global tts
global voicefixer
del tts
tts = None
import subprocess
subprocess.run(["python", "./src/train.py", "-opt", config_path], env=os.environ.copy(), shell=True, stdout=subprocess.PIPE)
"""
from train import train
train(config)
"""
def setup_voicefixer(restart=False):
global voicefixer
if restart:
del voicefixer
voicefixer = None
if args.voice_fixer and not restart:
try:
from voicefixer import VoiceFixer
print("Initializating voice-fixer")
from voicefixer import VoiceFixer
voicefixer = VoiceFixer()
print("initialized voice-fixer")
except Exception as e:
print(f"Error occurred while tring to initialize voicefixer: {e}")
def setup_tortoise(restart=False):
global args
global tts
if args.voice_fixer and not restart:
setup_voicefixer(restart=restart)
if restart:
del tts
tts = None
print("Initializating TorToiSe...")
tts = TextToSpeech(minor_optimizations=not args.low_vram)
get_model_path('dvae.pth')
print("TorToiSe initialized, ready for generation.")
return tts
@ -461,7 +487,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
f.write(yaml)
whisper_model = None
def prepare_dataset( files, outdir, language=None ):
global whisper_model
if whisper_model is None:
@ -641,9 +666,7 @@ def check_for_updates():
return False
def reload_tts():
global tts
del tts
tts = setup_tortoise(restart=True)
setup_tortoise(restart=True)
def cancel_generate():
tortoise.api.STOP_SIGNAL = True

View File

@ -123,6 +123,76 @@ def update_presets(value):
else:
return (gr.update(), gr.update())
def get_training_configs():
configs = []
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
if file[-5:] != ".yaml" or file[0] == ".":
continue
configs.append(f"./training/{file}")
return configs
def update_training_configs():
return gr.update(choices=get_training_configs())
def history_view_results( voice ):
results = []
files = []
outdir = f"./results/{voice}/"
for i, file in enumerate(sorted(os.listdir(outdir))):
if file[-4:] != ".wav":
continue
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
if metadata is None:
continue
values = []
for k in headers:
v = file
if k != "Name":
v = metadata[headers[k]]
values.append(v)
files.append(file)
results.append(values)
return (
results,
gr.Dropdown.update(choices=sorted(files))
)
def read_generate_settings_proxy(file, saveAs='.temp'):
j, latents = read_generate_settings(file)
if latents:
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
return (
j,
gr.update(value=latents, visible=latents is not None),
None if j is None else j['voice']
)
def prepare_dataset_proxy( voice, language ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
def update_voices():
return (
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list("./results/")),
)
def history_copy_settings( voice, file ):
return import_generate_settings( f"./results/{voice}/{file}" )
def setup_gradio():
global args
global ui
@ -280,34 +350,6 @@ def setup_gradio():
history_audio = gr.Audio()
history_copy_settings_button = gr.Button(value="Copy Settings")
def history_view_results( voice ):
results = []
files = []
outdir = f"./results/{voice}/"
for i, file in enumerate(sorted(os.listdir(outdir))):
if file[-4:] != ".wav":
continue
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
if metadata is None:
continue
values = []
for k in headers:
v = file
if k != "Name":
v = metadata[headers[k]]
values.append(v)
files.append(file)
results.append(values)
return (
results,
gr.Dropdown.update(choices=sorted(files))
)
history_view_results_button.click(
fn=history_view_results,
inputs=history_voices,
@ -335,23 +377,6 @@ def setup_gradio():
metadata_out = gr.JSON(label="Audio Metadata")
latents_out = gr.File(type="binary", label="Voice Latents")
def read_generate_settings_proxy(file, saveAs='.temp'):
j, latents = read_generate_settings(file)
if latents:
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
return (
j,
gr.update(value=latents, visible=latents is not None),
None if j is None else j['voice']
)
audio_in.upload(
fn=read_generate_settings_proxy,
inputs=audio_in,
@ -382,9 +407,6 @@ def setup_gradio():
with gr.Column():
prepare_dataset_button = gr.Button(value="Prepare")
def prepare_dataset_proxy( voice, language ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
prepare_dataset_button.click(
prepare_dataset_proxy,
inputs=dataset_settings,
@ -416,34 +438,12 @@ def setup_gradio():
with gr.Tab("Train"):
with gr.Row():
with gr.Column():
def get_training_configs():
configs = []
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
if file[-5:] != ".yaml" or file[0] == ".":
continue
configs.append(f"./training/{file}")
return configs
def update_training_configs():
return gr.update(choices=get_training_configs())
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
refresh_configs = gr.Button(value="Refresh Configurations")
train = gr.Button(value="Train")
def run_training_proxy( config ):
global tts
del tts
import subprocess
subprocess.run(["python", "./src/train.py", "-opt", config], env=os.environ.copy(), shell=True)
"""
from train import train
train(config)
"""
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
train.click(run_training_proxy,
train.click(run_training,
inputs=training_configs,
outputs=None
)
@ -506,17 +506,6 @@ def setup_gradio():
experimental_checkboxes,
]
# YUCK
def update_voices():
return (
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list("./results/")),
)
def history_copy_settings( voice, file ):
return import_generate_settings( f"./results/{voice}/{file}" )
refresh_voices.click(update_voices,
inputs=None,
outputs=[