finally added model refresh button, also searches in the training folder for outputted models so you don't even need to copy them

This commit is contained in:
mrq 2023-02-24 12:58:41 +00:00
parent 1e0fec4358
commit f6d0b66e10
2 changed files with 33 additions and 9 deletions

View File

@ -838,7 +838,17 @@ def get_autoregressive_models(dir="./models/finetunes/"):
if os.path.exists(halfp):
base.append(halfp)
return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
found = []
for training in os.listdir(f'./training/'):
if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'):
continue
#found = found + sorted([ f'./training/{training}/model/{d}' for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
found = found + [ f'./training/{training}/model/{d}_gpt.pth' for d in models ]
#found.append(f'./training/{training}/model/{models[-1]}_gpt.pth')
return base + additionals + found
def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])

View File

@ -343,9 +343,10 @@ def setup_gradio():
voice_list = get_voice_list(append_defaults=True)
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
refresh_voices = gr.Button(value="Refresh Voice List")
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
with gr.Row():
refresh_voices = gr.Button(value="Refresh Voice List")
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
def update_baseline_for_latents_chunks( voice ):
path = f'{get_voice_dir()}/{voice}/'
@ -524,7 +525,26 @@ def setup_gradio():
autoregressive_models = get_autoregressive_models()
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
with gr.Row():
autoregressive_models_update_button = gr.Button(value="Refresh Model List")
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="(Re)Load TTS").click(
reload_tts,
inputs=autoregressive_model_dropdown,
outputs=None
)
def update_model_list_proxy( val ):
autoregressive_models = get_autoregressive_models()
if val not in autoregressive_models:
val = autoregressive_models[0]
return gr.update( choices=autoregressive_models, value=val )
autoregressive_models_update_button.click(
update_model_list_proxy,
inputs=autoregressive_model_dropdown,
outputs=autoregressive_model_dropdown,
)
autoregressive_model_dropdown.change(
fn=update_autoregressive_model,
@ -537,12 +557,6 @@ def setup_gradio():
outputs=None
)
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="(Re)Load TTS").click(
reload_tts,
inputs=autoregressive_model_dropdown,
outputs=None
)
for i in exec_inputs:
i.change( fn=update_args, inputs=exec_inputs )