From ce24ba41e23a5406c1b2dd13fdfe285e944313c7 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Tue, 22 Aug 2023 23:09:42 +0200 Subject: [PATCH] Websocket server, override args parameters for model settings (squashed) Revert "favor existing arguments from parameters (kwargs) over global (args)" This reverts commit 89102347a956ebcfe9a83ae7d1aa1336f1c53483. args are now updated in the websocket server --- src/api/websocket_server.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 7695fbc..8e67447 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,8 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list - +from utils import generate, get_autoregressive_models, get_voice_list, args # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -19,6 +18,21 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): + global args + + # update args parameters which control the model settings + if message.get('autoregressive_model'): + args.autoregressive_model = message['autoregressive_model'] + + if message.get('diffusion_model'): + args.diffusion_model = message['diffusion_model'] + + if message.get('tokenizer_json'): + args.tokenizer_json = message['tokenizer_json'] + + if message.get('sample_batch_size'): + args.sample_batch_size = message['sample_batch_size'] + message['result'] = generate(**message) await websocket.send(json.dumps(replaceNoneStringWithNone(message)))