diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 3af0685..7695fbc 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -7,25 +7,30 @@ from websockets.server import serve from utils import generate, get_autoregressive_models, get_voice_list -async def _handle_generate(websocket, message): - await websocket.send(json.dumps(generate(**message))) +# this is a not so nice workaround to set values to None if their string value is "None" +def replaceNoneStringWithNone(message): + ignore_fields = ['text'] # list of fields which CAN have "None" as literal String value + for member in message: + if message[member] == 'None' and member not in ignore_fields: + message[member] = None -async def _handle_get_autoregressive_models(websocket, message): - await websocket.send(json.dumps(get_autoregressive_models())) + return message -async def _handle_get_voice_list(websocket, message): - await websocket.send(json.dumps(get_voice_list())) +async def _handle_generate(websocket, message): + message['result'] = generate(**message) + await websocket.send(json.dumps(replaceNoneStringWithNone(message))) -# this is a not so nice workaround to set values to None if their string value is "None" -def replaceNoneStringWithNone(message): - for member in message: - if message[member] == 'None': - message[member] = None +async def _handle_get_autoregressive_models(websocket, message): + message['result'] = get_autoregressive_models() + await websocket.send(json.dumps(replaceNoneStringWithNone(message))) - return message + +async def _handle_get_voice_list(websocket, message): + message['result'] = get_voice_list() + await websocket.send(json.dumps(replaceNoneStringWithNone(message))) async def _handle_message(websocket, message):