diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 8e67447..85e21cd 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): - global args - # update args parameters which control the model settings if message.get('autoregressive_model'): - args.autoregressive_model = message['autoregressive_model'] + update_autoregressive_model(message['autoregressive_model']) if message.get('diffusion_model'): - args.diffusion_model = message['diffusion_model'] + update_diffusion_model(message['diffusion_model']) if message.get('tokenizer_json'): - args.tokenizer_json = message['tokenizer_json'] + update_tokenizer(message['tokenizer_json']) if message.get('sample_batch_size'): + global args args.sample_batch_size = message['sample_batch_size'] message['result'] = generate(**message) @@ -71,7 +70,7 @@ async def _handle_connection(websocket, path): async def _run(host: str, port: int): - print("websocket: server started") + print(f"websocket: server started on ws://{host}:{port}") async with serve(_handle_connection, host, port, ping_interval=None): await asyncio.Future() # run forever diff --git a/src/main.py b/src/main.py index 418d616..dbb8d26 100755 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,9 @@ if __name__ == "__main__": if not args.defer_tts_load: tts = load_tts() - start_websocket_server('127.0.0.1', 8069) + if args.websocket_enabled: + start_websocket_server(args.websocket_listen_address, args.websocket_listen_port) + webui.block_thread() elif __name__ == "main": from fastapi import FastAPI diff --git a/src/utils.py b/src/utils.py index 85d9701..6457b77 100755 --- a/src/utils.py +++ b/src/utils.py @@ -3303,6 +3303,10 @@ def setup_args(cli=False): 'training-default-halfp': False, 'training-default-bnb': True, + + 'websocket-listen-address': "127.0.0.1", + 'websocket-listen-port': 8069, + 'websocket-enabled': False } if os.path.isfile('./config/exec.json'): @@ -3355,6 +3359,10 @@ def setup_args(cli=False): parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") + + parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069") + parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1") + parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false") parser.add_argument("--os", default="unix", help="Specifies which OS, easily") if cli: