diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 1349263..3af0685 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -6,18 +6,31 @@ from websockets.server import serve from utils import generate, get_autoregressive_models, get_voice_list + async def _handle_generate(websocket, message): await websocket.send(json.dumps(generate(**message))) + async def _handle_get_autoregressive_models(websocket, message): await websocket.send(json.dumps(get_autoregressive_models())) + async def _handle_get_voice_list(websocket, message): await websocket.send(json.dumps(get_voice_list())) +# this is a not so nice workaround to set values to None if their string value is "None" +def replaceNoneStringWithNone(message): + for member in message: + if message[member] == 'None': + message[member] = None + + return message + async def _handle_message(websocket, message): + message = replaceNoneStringWithNone(message) + if message.get('action') and message['action'] == 'generate': await _handle_generate(websocket, message) elif message.get('action') and message['action'] == 'get_voices': @@ -25,8 +38,7 @@ async def _handle_message(websocket, message): elif message.get('action') and message['action'] == 'get_autoregressive_models': await _handle_get_autoregressive_models(websocket, message) else: - print(message) - + print("websocket: undhandled message: " + message) async def _handle_connection(websocket, path): @@ -49,5 +61,6 @@ async def _run(host: str, port: int): def _run_server(listen_address: str, port: int): asyncio.run(_run(host=listen_address, port=port)) + def start_websocket_server(listen_address: str, port: int): Thread(target=_run_server, args=[listen_address, port], daemon=True).start()