This commit is contained in:
mrq 2023-03-12 17:51:52 +00:00
parent 9238df0b03
commit ccbf2e6aff
2 changed files with 3 additions and 10 deletions

View File

@ -35,7 +35,7 @@ from datetime import timedelta
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
from tortoise.utils.text import split_and_recombine_text from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
@ -1547,13 +1547,6 @@ def get_dataset_list(dir="./training/"):
def get_training_list(dir="./training/"): def get_training_list(dir="./training/"):
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
def do_gc():
gc.collect()
try:
torch.cuda.empty_cache()
except Exception as e:
pass
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)

View File

@ -301,6 +301,7 @@ def setup_gradio():
result_voices = get_voice_list("./results/") result_voices = get_voice_list("./results/")
autoregressive_models = get_autoregressive_models() autoregressive_models = get_autoregressive_models()
dataset_list = get_dataset_list() dataset_list = get_dataset_list()
training_list = get_training_list()
global GENERATE_SETTINGS_ARGS global GENERATE_SETTINGS_ARGS
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1] GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
@ -492,8 +493,7 @@ def setup_gradio():
with gr.Tab("Run Training"): with gr.Tab("Run Training"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_list = get_training_list() training_configs = gr.Dropdown(label="Training Configuration", choices=training_list, value=training_list[0] if len(training_list) else "")
training_configs = gr.Dropdown(label="Training Configuration", choices=training_list, value=training_list[0])
refresh_configs = gr.Button(value="Refresh Configurations") refresh_configs = gr.Button(value="Refresh Configurations")
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)