added simple websocket server which allows to start tts generation tasks, retrieving autoregressive models and voices list

This commit is contained in:
ben_mkiv 2023-08-16 12:51:13 +02:00
parent ac645e0a20
commit 2626364c40
2 changed files with 59 additions and 1 deletions

View File

@ -0,0 +1,53 @@
import asyncio
import json
from threading import Thread
from websockets.server import serve
from utils import generate, get_autoregressive_models, get_voice_list
async def _handle_generate(websocket, message):
await websocket.send(json.dumps(generate(**message)))
async def _handle_get_autoregressive_models(websocket, message):
await websocket.send(json.dumps(get_autoregressive_models()))
async def _handle_get_voice_list(websocket, message):
await websocket.send(json.dumps(get_voice_list()))
async def _handle_message(websocket, message):
if message.get('action') and message['action'] == 'generate':
await _handle_generate(websocket, message)
elif message.get('action') and message['action'] == 'get_voices':
await _handle_get_voice_list(websocket, message)
elif message.get('action') and message['action'] == 'get_autoregressive_models':
await _handle_get_autoregressive_models(websocket, message)
else:
print(message)
async def _handle_connection(websocket, path):
print("websocket: client connected")
async for message in websocket:
try:
await _handle_message(websocket, json.loads(message))
except ValueError:
print("websocket: malformed json received")
async def _run(host: str, port: int):
print("websocket: server started")
async with serve(_handle_connection, host, port, ping_interval=None):
await asyncio.Future() # run forever
def _run_server(listen_address: str, port: int):
asyncio.run(_run(host=listen_address, port=port))
def start_websocket_server(listen_address: str, port: int):
Thread(target=_run_server, args=[listen_address, port], daemon=True).start()

View File

@ -11,6 +11,9 @@ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
from utils import * from utils import *
from webui import * from webui import *
from api.websocket_server import start_websocket_server
if __name__ == "__main__": if __name__ == "__main__":
args = setup_args() args = setup_args()
@ -23,6 +26,7 @@ if __name__ == "__main__":
if not args.defer_tts_load: if not args.defer_tts_load:
tts = load_tts() tts = load_tts()
start_websocket_server('127.0.0.1', 8069)
webui.block_thread() webui.block_thread()
elif __name__ == "main": elif __name__ == "main":
from fastapi import FastAPI from fastapi import FastAPI
@ -38,3 +42,4 @@ elif __name__ == "main":
if not args.defer_tts_load: if not args.defer_tts_load:
tts = load_tts() tts = load_tts()