forked from mrq/ai-voice-cloning
blame mrq/ai-voice-cloning#122
This commit is contained in:
parent
9238df0b03
commit
ccbf2e6aff
|
@ -35,7 +35,7 @@ from datetime import timedelta
|
||||||
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
|
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
|
||||||
|
|
||||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
|
||||||
|
@ -1547,13 +1547,6 @@ def get_dataset_list(dir="./training/"):
|
||||||
def get_training_list(dir="./training/"):
|
def get_training_list(dir="./training/"):
|
||||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
|
||||||
def do_gc():
|
|
||||||
gc.collect()
|
|
||||||
try:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
|
|
|
@ -301,6 +301,7 @@ def setup_gradio():
|
||||||
result_voices = get_voice_list("./results/")
|
result_voices = get_voice_list("./results/")
|
||||||
autoregressive_models = get_autoregressive_models()
|
autoregressive_models = get_autoregressive_models()
|
||||||
dataset_list = get_dataset_list()
|
dataset_list = get_dataset_list()
|
||||||
|
training_list = get_training_list()
|
||||||
|
|
||||||
global GENERATE_SETTINGS_ARGS
|
global GENERATE_SETTINGS_ARGS
|
||||||
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
|
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
|
||||||
|
@ -492,8 +493,7 @@ def setup_gradio():
|
||||||
with gr.Tab("Run Training"):
|
with gr.Tab("Run Training"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_list = get_training_list()
|
training_configs = gr.Dropdown(label="Training Configuration", choices=training_list, value=training_list[0] if len(training_list) else "")
|
||||||
training_configs = gr.Dropdown(label="Training Configuration", choices=training_list, value=training_list[0])
|
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user