added websocket server arguments to enabled it (now disabled by default) and to specify the address/port to listen on

This commit is contained in:
ben_mkiv 2023-08-26 17:38:58 +02:00
parent 6f0f148782
commit b72f2216bf
3 changed files with 13 additions and 3 deletions

View File

@ -4,7 +4,7 @@ from threading import Thread
from websockets.server import serve from websockets.server import serve
from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer, tts
# this is a not so nice workaround to set values to None if their string value is "None" # this is a not so nice workaround to set values to None if their string value is "None"
def replaceNoneStringWithNone(message): def replaceNoneStringWithNone(message):
@ -70,7 +70,7 @@ async def _handle_connection(websocket, path):
async def _run(host: str, port: int): async def _run(host: str, port: int):
print("websocket: server started") print(f"websocket: server started on ws://{host}:{port}")
async with serve(_handle_connection, host, port, ping_interval=None): async with serve(_handle_connection, host, port, ping_interval=None):
await asyncio.Future() # run forever await asyncio.Future() # run forever

View File

@ -26,7 +26,9 @@ if __name__ == "__main__":
if not args.defer_tts_load: if not args.defer_tts_load:
tts = load_tts() tts = load_tts()
start_websocket_server('127.0.0.1', 8069) if args.websocket_enabled:
start_websocket_server(args.websocket_listen_address, args.websocket_listen_port)
webui.block_thread() webui.block_thread()
elif __name__ == "main": elif __name__ == "main":
from fastapi import FastAPI from fastapi import FastAPI

View File

@ -3300,6 +3300,10 @@ def setup_args(cli=False):
'training-default-halfp': False, 'training-default-halfp': False,
'training-default-bnb': True, 'training-default-bnb': True,
'websocket-listen-address': "127.0.0.1",
'websocket-listen-port': 8069,
'websocket-enabled': False
} }
if os.path.isfile('./config/exec.json'): if os.path.isfile('./config/exec.json'):
@ -3352,6 +3356,10 @@ def setup_args(cli=False):
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily") parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
if cli: if cli: