cleanup, "injected" dvae.pth to download through tortoise's model loader, so I don't need to keep copying it
This commit is contained in:
parent
13c9920b7f
commit
bcec64af0f
15
src/train.py
15
src/train.py
|
@ -4,12 +4,27 @@ import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||||
|
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||||
|
|
||||||
sys.path.insert(0, './dlas/codes/')
|
sys.path.insert(0, './dlas/codes/')
|
||||||
|
# this is also because DLAS is not written as a package in mind
|
||||||
|
# it'll gripe when it wants to import from train.py
|
||||||
sys.path.insert(0, './dlas/')
|
sys.path.insert(0, './dlas/')
|
||||||
|
|
||||||
|
# for PIP, replace it with:
|
||||||
|
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
|
||||||
|
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
|
||||||
|
|
||||||
|
# don't even really bother trying to get DLAS PIP'd
|
||||||
|
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
||||||
|
|
||||||
from codes import train as tr
|
from codes import train as tr
|
||||||
from utils import util, options as option
|
from utils import util, options as option
|
||||||
|
|
||||||
|
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||||
|
# I'll clean it up better
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
|
|
53
src/utils.py
53
src/utils.py
|
@ -24,18 +24,20 @@ import gradio.utils
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech
|
from tortoise.api import TextToSpeech, MODELS, get_model_path
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name
|
from tortoise.utils.device import get_device_name, set_device_name
|
||||||
|
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
webui = None
|
webui = None
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
dlas = None
|
whisper_model = None
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
global args
|
global args
|
||||||
|
@ -53,7 +55,7 @@ def setup_args():
|
||||||
'sample-batch-size': None,
|
'sample-batch-size': None,
|
||||||
'embed-output-metadata': True,
|
'embed-output-metadata': True,
|
||||||
'latents-lean-and-mean': True,
|
'latents-lean-and-mean': True,
|
||||||
'voice-fixer': False, # I'm tired of long initialization of Colab notebooks
|
'voice-fixer': True,
|
||||||
'voice-fixer-use-cuda': True,
|
'voice-fixer-use-cuda': True,
|
||||||
'force-cpu-for-conditioning-latents': False,
|
'force-cpu-for-conditioning-latents': False,
|
||||||
'device-override': None,
|
'device-override': None,
|
||||||
|
@ -420,22 +422,46 @@ def generate(
|
||||||
stats,
|
stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def run_training(config_path):
|
||||||
|
global tts
|
||||||
|
del tts
|
||||||
|
tts = None
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(["python", "./src/train.py", "-opt", config_path], env=os.environ.copy(), shell=True, stdout=subprocess.PIPE)
|
||||||
|
"""
|
||||||
|
from train import train
|
||||||
|
train(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setup_voicefixer(restart=False):
|
||||||
|
global voicefixer
|
||||||
|
if restart:
|
||||||
|
del voicefixer
|
||||||
|
voicefixer = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Initializating voice-fixer")
|
||||||
|
from voicefixer import VoiceFixer
|
||||||
|
voicefixer = VoiceFixer()
|
||||||
|
print("initialized voice-fixer")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
||||||
|
|
||||||
def setup_tortoise(restart=False):
|
def setup_tortoise(restart=False):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
global voicefixer
|
|
||||||
|
|
||||||
if args.voice_fixer and not restart:
|
if args.voice_fixer and not restart:
|
||||||
try:
|
setup_voicefixer(restart=restart)
|
||||||
from voicefixer import VoiceFixer
|
|
||||||
print("Initializating voice-fixer")
|
if restart:
|
||||||
voicefixer = VoiceFixer()
|
del tts
|
||||||
print("initialized voice-fixer")
|
tts = None
|
||||||
except Exception as e:
|
|
||||||
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
|
||||||
|
|
||||||
print("Initializating TorToiSe...")
|
print("Initializating TorToiSe...")
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
|
get_model_path('dvae.pth')
|
||||||
print("TorToiSe initialized, ready for generation.")
|
print("TorToiSe initialized, ready for generation.")
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
|
@ -461,7 +487,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
||||||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||||
f.write(yaml)
|
f.write(yaml)
|
||||||
|
|
||||||
whisper_model = None
|
|
||||||
def prepare_dataset( files, outdir, language=None ):
|
def prepare_dataset( files, outdir, language=None ):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
|
@ -641,9 +666,7 @@ def check_for_updates():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def reload_tts():
|
def reload_tts():
|
||||||
global tts
|
setup_tortoise(restart=True)
|
||||||
del tts
|
|
||||||
tts = setup_tortoise(restart=True)
|
|
||||||
|
|
||||||
def cancel_generate():
|
def cancel_generate():
|
||||||
tortoise.api.STOP_SIGNAL = True
|
tortoise.api.STOP_SIGNAL = True
|
||||||
|
|
155
src/webui.py
155
src/webui.py
|
@ -123,6 +123,76 @@ def update_presets(value):
|
||||||
else:
|
else:
|
||||||
return (gr.update(), gr.update())
|
return (gr.update(), gr.update())
|
||||||
|
|
||||||
|
def get_training_configs():
|
||||||
|
configs = []
|
||||||
|
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
|
||||||
|
if file[-5:] != ".yaml" or file[0] == ".":
|
||||||
|
continue
|
||||||
|
configs.append(f"./training/{file}")
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
def update_training_configs():
|
||||||
|
return gr.update(choices=get_training_configs())
|
||||||
|
|
||||||
|
def history_view_results( voice ):
|
||||||
|
results = []
|
||||||
|
files = []
|
||||||
|
outdir = f"./results/{voice}/"
|
||||||
|
for i, file in enumerate(sorted(os.listdir(outdir))):
|
||||||
|
if file[-4:] != ".wav":
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
|
||||||
|
if metadata is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for k in headers:
|
||||||
|
v = file
|
||||||
|
if k != "Name":
|
||||||
|
v = metadata[headers[k]]
|
||||||
|
values.append(v)
|
||||||
|
|
||||||
|
|
||||||
|
files.append(file)
|
||||||
|
results.append(values)
|
||||||
|
|
||||||
|
return (
|
||||||
|
results,
|
||||||
|
gr.Dropdown.update(choices=sorted(files))
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
|
j, latents = read_generate_settings(file)
|
||||||
|
|
||||||
|
if latents:
|
||||||
|
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||||
|
f.write(latents)
|
||||||
|
|
||||||
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
|
||||||
|
return (
|
||||||
|
j,
|
||||||
|
gr.update(value=latents, visible=latents is not None),
|
||||||
|
None if j is None else j['voice']
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_dataset_proxy( voice, language ):
|
||||||
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
||||||
|
|
||||||
|
def update_voices():
|
||||||
|
return (
|
||||||
|
gr.Dropdown.update(choices=get_voice_list()),
|
||||||
|
gr.Dropdown.update(choices=get_voice_list()),
|
||||||
|
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||||
|
)
|
||||||
|
|
||||||
|
def history_copy_settings( voice, file ):
|
||||||
|
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||||
|
|
||||||
def setup_gradio():
|
def setup_gradio():
|
||||||
global args
|
global args
|
||||||
global ui
|
global ui
|
||||||
|
@ -279,34 +349,6 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_audio = gr.Audio()
|
history_audio = gr.Audio()
|
||||||
history_copy_settings_button = gr.Button(value="Copy Settings")
|
history_copy_settings_button = gr.Button(value="Copy Settings")
|
||||||
|
|
||||||
def history_view_results( voice ):
|
|
||||||
results = []
|
|
||||||
files = []
|
|
||||||
outdir = f"./results/{voice}/"
|
|
||||||
for i, file in enumerate(sorted(os.listdir(outdir))):
|
|
||||||
if file[-4:] != ".wav":
|
|
||||||
continue
|
|
||||||
|
|
||||||
metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
|
|
||||||
if metadata is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
values = []
|
|
||||||
for k in headers:
|
|
||||||
v = file
|
|
||||||
if k != "Name":
|
|
||||||
v = metadata[headers[k]]
|
|
||||||
values.append(v)
|
|
||||||
|
|
||||||
|
|
||||||
files.append(file)
|
|
||||||
results.append(values)
|
|
||||||
|
|
||||||
return (
|
|
||||||
results,
|
|
||||||
gr.Dropdown.update(choices=sorted(files))
|
|
||||||
)
|
|
||||||
|
|
||||||
history_view_results_button.click(
|
history_view_results_button.click(
|
||||||
fn=history_view_results,
|
fn=history_view_results,
|
||||||
|
@ -335,23 +377,6 @@ def setup_gradio():
|
||||||
metadata_out = gr.JSON(label="Audio Metadata")
|
metadata_out = gr.JSON(label="Audio Metadata")
|
||||||
latents_out = gr.File(type="binary", label="Voice Latents")
|
latents_out = gr.File(type="binary", label="Voice Latents")
|
||||||
|
|
||||||
def read_generate_settings_proxy(file, saveAs='.temp'):
|
|
||||||
j, latents = read_generate_settings(file)
|
|
||||||
|
|
||||||
if latents:
|
|
||||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
|
||||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
|
||||||
f.write(latents)
|
|
||||||
|
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
|
||||||
|
|
||||||
return (
|
|
||||||
j,
|
|
||||||
gr.update(value=latents, visible=latents is not None),
|
|
||||||
None if j is None else j['voice']
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_in.upload(
|
audio_in.upload(
|
||||||
fn=read_generate_settings_proxy,
|
fn=read_generate_settings_proxy,
|
||||||
inputs=audio_in,
|
inputs=audio_in,
|
||||||
|
@ -382,9 +407,6 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
prepare_dataset_button = gr.Button(value="Prepare")
|
prepare_dataset_button = gr.Button(value="Prepare")
|
||||||
|
|
||||||
def prepare_dataset_proxy( voice, language ):
|
|
||||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
|
||||||
|
|
||||||
prepare_dataset_button.click(
|
prepare_dataset_button.click(
|
||||||
prepare_dataset_proxy,
|
prepare_dataset_proxy,
|
||||||
inputs=dataset_settings,
|
inputs=dataset_settings,
|
||||||
|
@ -416,34 +438,12 @@ def setup_gradio():
|
||||||
with gr.Tab("Train"):
|
with gr.Tab("Train"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
def get_training_configs():
|
|
||||||
configs = []
|
|
||||||
for i, file in enumerate(sorted(os.listdir(f"./training/"))):
|
|
||||||
if file[-5:] != ".yaml" or file[0] == ".":
|
|
||||||
continue
|
|
||||||
configs.append(f"./training/{file}")
|
|
||||||
|
|
||||||
return configs
|
|
||||||
def update_training_configs():
|
|
||||||
return gr.update(choices=get_training_configs())
|
|
||||||
|
|
||||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
train = gr.Button(value="Train")
|
train = gr.Button(value="Train")
|
||||||
|
|
||||||
def run_training_proxy( config ):
|
|
||||||
global tts
|
|
||||||
del tts
|
|
||||||
|
|
||||||
import subprocess
|
|
||||||
subprocess.run(["python", "./src/train.py", "-opt", config], env=os.environ.copy(), shell=True)
|
|
||||||
"""
|
|
||||||
from train import train
|
|
||||||
train(config)
|
|
||||||
"""
|
|
||||||
|
|
||||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||||
train.click(run_training_proxy,
|
train.click(run_training,
|
||||||
inputs=training_configs,
|
inputs=training_configs,
|
||||||
outputs=None
|
outputs=None
|
||||||
)
|
)
|
||||||
|
@ -506,17 +506,6 @@ def setup_gradio():
|
||||||
experimental_checkboxes,
|
experimental_checkboxes,
|
||||||
]
|
]
|
||||||
|
|
||||||
# YUCK
|
|
||||||
def update_voices():
|
|
||||||
return (
|
|
||||||
gr.Dropdown.update(choices=get_voice_list()),
|
|
||||||
gr.Dropdown.update(choices=get_voice_list()),
|
|
||||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
|
||||||
)
|
|
||||||
|
|
||||||
def history_copy_settings( voice, file ):
|
|
||||||
return import_generate_settings( f"./results/{voice}/{file}" )
|
|
||||||
|
|
||||||
refresh_voices.click(update_voices,
|
refresh_voices.click(update_voices,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
|
Loading…
Reference in New Issue
Block a user