forked from mrq/ai-voice-cloning
Merge pull request 'websocket server: small fix' (#333) from ben_mkiv/ai-voice-cloning:master into master
Reviewed-on: mrq/ai-voice-cloning#333
This commit is contained in:
commit
fb1cfd059f
|
@ -6,18 +6,31 @@ from websockets.server import serve
|
||||||
|
|
||||||
from utils import generate, get_autoregressive_models, get_voice_list
|
from utils import generate, get_autoregressive_models, get_voice_list
|
||||||
|
|
||||||
|
|
||||||
async def _handle_generate(websocket, message):
|
async def _handle_generate(websocket, message):
|
||||||
await websocket.send(json.dumps(generate(**message)))
|
await websocket.send(json.dumps(generate(**message)))
|
||||||
|
|
||||||
|
|
||||||
async def _handle_get_autoregressive_models(websocket, message):
|
async def _handle_get_autoregressive_models(websocket, message):
|
||||||
await websocket.send(json.dumps(get_autoregressive_models()))
|
await websocket.send(json.dumps(get_autoregressive_models()))
|
||||||
|
|
||||||
|
|
||||||
async def _handle_get_voice_list(websocket, message):
|
async def _handle_get_voice_list(websocket, message):
|
||||||
await websocket.send(json.dumps(get_voice_list()))
|
await websocket.send(json.dumps(get_voice_list()))
|
||||||
|
|
||||||
|
|
||||||
|
# this is a not so nice workaround to set values to None if their string value is "None"
|
||||||
|
def replaceNoneStringWithNone(message):
|
||||||
|
for member in message:
|
||||||
|
if message[member] == 'None':
|
||||||
|
message[member] = None
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
async def _handle_message(websocket, message):
|
async def _handle_message(websocket, message):
|
||||||
|
message = replaceNoneStringWithNone(message)
|
||||||
|
|
||||||
if message.get('action') and message['action'] == 'generate':
|
if message.get('action') and message['action'] == 'generate':
|
||||||
await _handle_generate(websocket, message)
|
await _handle_generate(websocket, message)
|
||||||
elif message.get('action') and message['action'] == 'get_voices':
|
elif message.get('action') and message['action'] == 'get_voices':
|
||||||
|
@ -25,8 +38,7 @@ async def _handle_message(websocket, message):
|
||||||
elif message.get('action') and message['action'] == 'get_autoregressive_models':
|
elif message.get('action') and message['action'] == 'get_autoregressive_models':
|
||||||
await _handle_get_autoregressive_models(websocket, message)
|
await _handle_get_autoregressive_models(websocket, message)
|
||||||
else:
|
else:
|
||||||
print(message)
|
print("websocket: undhandled message: " + message)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_connection(websocket, path):
|
async def _handle_connection(websocket, path):
|
||||||
|
@ -49,5 +61,6 @@ async def _run(host: str, port: int):
|
||||||
def _run_server(listen_address: str, port: int):
|
def _run_server(listen_address: str, port: int):
|
||||||
asyncio.run(_run(host=listen_address, port=port))
|
asyncio.run(_run(host=listen_address, port=port))
|
||||||
|
|
||||||
|
|
||||||
def start_websocket_server(listen_address: str, port: int):
|
def start_websocket_server(listen_address: str, port: int):
|
||||||
Thread(target=_run_server, args=[listen_address, port], daemon=True).start()
|
Thread(target=_run_server, args=[listen_address, port], daemon=True).start()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user