diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py new file mode 100644 index 0000000..1349263 --- /dev/null +++ b/src/api/websocket_server.py @@ -0,0 +1,53 @@ +import asyncio +import json +from threading import Thread + +from websockets.server import serve + +from utils import generate, get_autoregressive_models, get_voice_list + +async def _handle_generate(websocket, message): + await websocket.send(json.dumps(generate(**message))) + +async def _handle_get_autoregressive_models(websocket, message): + await websocket.send(json.dumps(get_autoregressive_models())) + +async def _handle_get_voice_list(websocket, message): + await websocket.send(json.dumps(get_voice_list())) + + + +async def _handle_message(websocket, message): + if message.get('action') and message['action'] == 'generate': + await _handle_generate(websocket, message) + elif message.get('action') and message['action'] == 'get_voices': + await _handle_get_voice_list(websocket, message) + elif message.get('action') and message['action'] == 'get_autoregressive_models': + await _handle_get_autoregressive_models(websocket, message) + else: + print(message) + + + +async def _handle_connection(websocket, path): + print("websocket: client connected") + + async for message in websocket: + try: + await _handle_message(websocket, json.loads(message)) + except ValueError: + print("websocket: malformed json received") + + +async def _run(host: str, port: int): + print("websocket: server started") + + async with serve(_handle_connection, host, port, ping_interval=None): + await asyncio.Future() # run forever + + +def _run_server(listen_address: str, port: int): + asyncio.run(_run(host=listen_address, port=port)) + +def start_websocket_server(listen_address: str, port: int): + Thread(target=_run_server, args=[listen_address, port], daemon=True).start() diff --git a/src/main.py b/src/main.py index d648b02..418d616 100755 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,9 @@ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' from utils import * from webui import * +from api.websocket_server import start_websocket_server + + if __name__ == "__main__": args = setup_args() @@ -23,6 +26,7 @@ if __name__ == "__main__": if not args.defer_tts_load: tts = load_tts() + start_websocket_server('127.0.0.1', 8069) webui.block_thread() elif __name__ == "main": from fastapi import FastAPI @@ -37,4 +41,5 @@ elif __name__ == "main": app = gr.mount_gradio_app(app, webui, path=args.listen_path) if not args.defer_tts_load: - tts = load_tts() \ No newline at end of file + tts = load_tts() +