I forgot to make it update the whisper model at runtime

This commit is contained in:
mrq 2023-02-19 01:47:06 +00:00
parent 47058db67f
commit 5fcdb19f8b
2 changed files with 13 additions and 4 deletions

View File

@ -765,6 +765,17 @@ def get_dataset_list(dir="./training/"):
def get_training_list(dir="./training/"): def get_training_list(dir="./training/"):
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
def update_whisper_model(name):
global whisper_model
if whisper_model:
del whisper_model
whisper_model = None
args.whisper_model = name
print(f"Loading Whisper model: {args.whisper_model}")
whisper_model = whisper.load_model(args.whisper_model)
def update_autoregressive_model(path_name): def update_autoregressive_model(path_name):
global tts global tts
if not tts: if not tts:

View File

@ -209,10 +209,8 @@ def history_copy_settings( voice, file ):
return import_generate_settings( f"./results/{voice}/{file}" ) return import_generate_settings( f"./results/{voice}/{file}" )
def update_model_settings( autoregressive_model, whisper_model ): def update_model_settings( autoregressive_model, whisper_model ):
if args.autoregressive_model != autoregressive_model:
update_autoregressive_model(autoregressive_model) update_autoregressive_model(autoregressive_model)
update_whisper_model(whisper_model)
args.whisper_model = whisper_model
save_args_settings() save_args_settings()