From 578a5bcadd77df6de0881532086fe116f16ea3b4 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Sat, 26 Aug 2023 01:40:35 +0200 Subject: [PATCH 1/3] websocket server: fix for model loading (just overriding args didn't do it after all...) --- src/api/websocket_server.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 8e67447..e9facbb 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): - global args - # update args parameters which control the model settings if message.get('autoregressive_model'): - args.autoregressive_model = message['autoregressive_model'] + update_autoregressive_model(message['autoregressive_model']) if message.get('diffusion_model'): - args.diffusion_model = message['diffusion_model'] + update_diffusion_model(message['diffusion_model']) if message.get('tokenizer_json'): - args.tokenizer_json = message['tokenizer_json'] + update_tokenizer(message['tokenizer_json']) if message.get('sample_batch_size'): + global args args.sample_batch_size = message['sample_batch_size'] message['result'] = generate(**message) From 6f0f1487823ff1b43d5a2b42ef4d118623e6ca76 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Sat, 26 Aug 2023 01:40:35 +0200 Subject: [PATCH 2/3] websocket server: fix for model loading (just overriding args didn't do it after all...) --- src/api/websocket_server.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 8e67447..e9facbb 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): - global args - # update args parameters which control the model settings if message.get('autoregressive_model'): - args.autoregressive_model = message['autoregressive_model'] + update_autoregressive_model(message['autoregressive_model']) if message.get('diffusion_model'): - args.diffusion_model = message['diffusion_model'] + update_diffusion_model(message['diffusion_model']) if message.get('tokenizer_json'): - args.tokenizer_json = message['tokenizer_json'] + update_tokenizer(message['tokenizer_json']) if message.get('sample_batch_size'): + global args args.sample_batch_size = message['sample_batch_size'] message['result'] = generate(**message) From b72f2216bf466ca4de484dd19736dcefc52e0eb3 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Sat, 26 Aug 2023 17:38:58 +0200 Subject: [PATCH 3/3] added websocket server arguments to enabled it (now disabled by default) and to specify the address/port to listen on --- src/api/websocket_server.py | 4 ++-- src/main.py | 4 +++- src/utils.py | 8 ++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index e9facbb..d11b387 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer +from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer, tts # this is a not so nice workaround to set values to None if their string value is "None" def replaceNoneStringWithNone(message): @@ -70,7 +70,7 @@ async def _handle_connection(websocket, path): async def _run(host: str, port: int): - print("websocket: server started") + print(f"websocket: server started on ws://{host}:{port}") async with serve(_handle_connection, host, port, ping_interval=None): await asyncio.Future() # run forever diff --git a/src/main.py b/src/main.py index 418d616..dbb8d26 100755 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,9 @@ if __name__ == "__main__": if not args.defer_tts_load: tts = load_tts() - start_websocket_server('127.0.0.1', 8069) + if args.websocket_enabled: + start_websocket_server(args.websocket_listen_address, args.websocket_listen_port) + webui.block_thread() elif __name__ == "main": from fastapi import FastAPI diff --git a/src/utils.py b/src/utils.py index 00d9c24..d4fbb3e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -3300,6 +3300,10 @@ def setup_args(cli=False): 'training-default-halfp': False, 'training-default-bnb': True, + + 'websocket-listen-address': "127.0.0.1", + 'websocket-listen-port': 8069, + 'websocket-enabled': False } if os.path.isfile('./config/exec.json'): @@ -3352,6 +3356,10 @@ def setup_args(cli=False): parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") + + parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069") + parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1") + parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false") parser.add_argument("--os", default="unix", help="Specifies which OS, easily") if cli: