From f6d0b66e1009d5a8c797219761eed1033a4735ce Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 24 Feb 2023 12:58:41 +0000 Subject: [PATCH] finally added model refresh button, also searches in the training folder for outputted models so you don't even need to copy them --- src/utils.py | 12 +++++++++++- src/webui.py | 30 ++++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/utils.py b/src/utils.py index f80e231..162361e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -838,7 +838,17 @@ def get_autoregressive_models(dir="./models/finetunes/"): if os.path.exists(halfp): base.append(halfp) - return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) + additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) + found = [] + for training in os.listdir(f'./training/'): + if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'): + continue + #found = found + sorted([ f'./training/{training}/model/{d}' for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) + models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) + found = found + [ f'./training/{training}/model/{d}_gpt.pth' for d in models ] + #found.append(f'./training/{training}/model/{models[-1]}_gpt.pth') + + return base + additionals + found def get_dataset_list(dir="./training/"): return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ]) diff --git a/src/webui.py b/src/webui.py index a9aaf2a..a2e306e 100755 --- a/src/webui.py +++ b/src/webui.py @@ -343,9 +343,10 @@ def setup_gradio(): voice_list = get_voice_list(append_defaults=True) voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) - refresh_voices = gr.Button(value="Refresh Voice List") voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) - recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") + with gr.Row(): + refresh_voices = gr.Button(value="Refresh Voice List") + recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") def update_baseline_for_latents_chunks( voice ): path = f'{get_voice_dir()}/{voice}/' @@ -524,7 +525,26 @@ def setup_gradio(): autoregressive_models = get_autoregressive_models() autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model) + with gr.Row(): + autoregressive_models_update_button = gr.Button(value="Refresh Model List") + gr.Button(value="Check for Updates").click(check_for_updates) + gr.Button(value="(Re)Load TTS").click( + reload_tts, + inputs=autoregressive_model_dropdown, + outputs=None + ) + def update_model_list_proxy( val ): + autoregressive_models = get_autoregressive_models() + if val not in autoregressive_models: + val = autoregressive_models[0] + return gr.update( choices=autoregressive_models, value=val ) + + autoregressive_models_update_button.click( + update_model_list_proxy, + inputs=autoregressive_model_dropdown, + outputs=autoregressive_model_dropdown, + ) autoregressive_model_dropdown.change( fn=update_autoregressive_model, @@ -537,12 +557,6 @@ def setup_gradio(): outputs=None ) - gr.Button(value="Check for Updates").click(check_for_updates) - gr.Button(value="(Re)Load TTS").click( - reload_tts, - inputs=autoregressive_model_dropdown, - outputs=None - ) for i in exec_inputs: i.change( fn=update_args, inputs=exec_inputs )