From b72f2216bf466ca4de484dd19736dcefc52e0eb3 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Sat, 26 Aug 2023 17:38:58 +0200 Subject: [PATCH] added websocket server arguments to enabled it (now disabled by default) and to specify the address/port to listen on --- src/api/websocket_server.py | 4 ++-- src/main.py | 4 +++- src/utils.py | 8 ++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index e9facbb..d11b387 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer, tts # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -70,7 +70,7 @@ async def _handle_connection(websocket, path): async def _run(host: str, port: int): - print("websocket: server started") + print(f"websocket: server started on ws://{host}:{port}") async with serve(_handle_connection, host, port, ping_interval=None): await asyncio.Future() # run forever diff --git a/src/main.py b/src/main.py index 418d616..dbb8d26 100755 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,9 @@ if __name__ == "__main__": if not args.defer_tts_load: tts = load_tts() - start_websocket_server('127.0.0.1', 8069) + if args.websocket_enabled: + start_websocket_server(args.websocket_listen_address, args.websocket_listen_port) + webui.block_thread() elif __name__ == "main": from fastapi import FastAPI diff --git a/src/utils.py b/src/utils.py index 00d9c24..d4fbb3e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -3300,6 +3300,10 @@ def setup_args(cli=False): 'training-default-halfp': False, 'training-default-bnb': True, + + 'websocket-listen-address': "127.0.0.1", + 'websocket-listen-port': 8069, + 'websocket-enabled': False } if os.path.isfile('./config/exec.json'): @@ -3352,6 +3356,10 @@ def setup_args(cli=False): parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") + + parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069") + parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1") + parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false") parser.add_argument("--os", default="unix", help="Specifies which OS, easily") if cli: