Merge pull request 'added simple websocket server which allows to start tts generation tasks, retrieving autoregressive models and voices list' (#328) from ben_mkiv/ai-voice-cloning:master into master
Reviewed-on: #328
This commit is contained in:
commit
91a0c495ff
53
src/api/websocket_server.py
Normal file
53
src/api/websocket_server.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import asyncio
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
from websockets.server import serve
|
||||
|
||||
from utils import generate, get_autoregressive_models, get_voice_list
|
||||
|
||||
async def _handle_generate(websocket, message):
|
||||
await websocket.send(json.dumps(generate(**message)))
|
||||
|
||||
async def _handle_get_autoregressive_models(websocket, message):
|
||||
await websocket.send(json.dumps(get_autoregressive_models()))
|
||||
|
||||
async def _handle_get_voice_list(websocket, message):
|
||||
await websocket.send(json.dumps(get_voice_list()))
|
||||
|
||||
|
||||
|
||||
async def _handle_message(websocket, message):
|
||||
if message.get('action') and message['action'] == 'generate':
|
||||
await _handle_generate(websocket, message)
|
||||
elif message.get('action') and message['action'] == 'get_voices':
|
||||
await _handle_get_voice_list(websocket, message)
|
||||
elif message.get('action') and message['action'] == 'get_autoregressive_models':
|
||||
await _handle_get_autoregressive_models(websocket, message)
|
||||
else:
|
||||
print(message)
|
||||
|
||||
|
||||
|
||||
async def _handle_connection(websocket, path):
|
||||
print("websocket: client connected")
|
||||
|
||||
async for message in websocket:
|
||||
try:
|
||||
await _handle_message(websocket, json.loads(message))
|
||||
except ValueError:
|
||||
print("websocket: malformed json received")
|
||||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
print("websocket: server started")
|
||||
|
||||
async with serve(_handle_connection, host, port, ping_interval=None):
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
def _run_server(listen_address: str, port: int):
|
||||
asyncio.run(_run(host=listen_address, port=port))
|
||||
|
||||
def start_websocket_server(listen_address: str, port: int):
|
||||
Thread(target=_run_server, args=[listen_address, port], daemon=True).start()
|
|
@ -11,6 +11,9 @@ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
|
|||
from utils import *
|
||||
from webui import *
|
||||
|
||||
from api.websocket_server import start_websocket_server
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = setup_args()
|
||||
|
||||
|
@ -23,6 +26,7 @@ if __name__ == "__main__":
|
|||
if not args.defer_tts_load:
|
||||
tts = load_tts()
|
||||
|
||||
start_websocket_server('127.0.0.1', 8069)
|
||||
webui.block_thread()
|
||||
elif __name__ == "main":
|
||||
from fastapi import FastAPI
|
||||
|
@ -37,4 +41,5 @@ elif __name__ == "main":
|
|||
app = gr.mount_gradio_app(app, webui, path=args.listen_path)
|
||||
|
||||
if not args.defer_tts_load:
|
||||
tts = load_tts()
|
||||
tts = load_tts()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user