diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 7695fbc..8e67447 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,8 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list - +from utils import generate, get_autoregressive_models, get_voice_list, args # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -19,6 +18,21 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): + global args + + # update args parameters which control the model settings + if message.get('autoregressive_model'): + args.autoregressive_model = message['autoregressive_model'] + + if message.get('diffusion_model'): + args.diffusion_model = message['diffusion_model'] + + if message.get('tokenizer_json'): + args.tokenizer_json = message['tokenizer_json'] + + if message.get('sample_batch_size'): + args.sample_batch_size = message['sample_batch_size'] + message['result'] = generate(**message) await websocket.send(json.dumps(replaceNoneStringWithNone(message)))