From 578a5bcadd77df6de0881532086fe116f16ea3b4 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Sat, 26 Aug 2023 01:40:35 +0200 Subject: [PATCH] websocket server: fix for model loading (just overriding args didn't do it after all...) --- src/api/websocket_server.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 8e67447..e9facbb 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): - global args - # update args parameters which control the model settings if message.get('autoregressive_model'): - args.autoregressive_model = message['autoregressive_model'] + update_autoregressive_model(message['autoregressive_model']) if message.get('diffusion_model'): - args.diffusion_model = message['diffusion_model'] + update_diffusion_model(message['diffusion_model']) if message.get('tokenizer_json'): - args.tokenizer_json = message['tokenizer_json'] + update_tokenizer(message['tokenizer_json']) if message.get('sample_batch_size'): + global args args.sample_batch_size = message['sample_batch_size'] message['result'] = generate(**message)