forked from mrq/ai-voice-cloning
Merge pull request 'Websocket fixes / additions' (#350) from ben_mkiv/ai-voice-cloning:master into master
Reviewed-on: mrq/ai-voice-cloning#350
This commit is contained in:
commit
7110b878b7
|
@ -4,7 +4,7 @@ from threading import Thread
|
|||
|
||||
from websockets.server import serve
|
||||
|
||||
from utils import generate, get_autoregressive_models, get_voice_list, args
|
||||
from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer
|
||||
|
||||
# this is a not so nice workaround to set values to None if their string value is "None"
|
||||
def replaceNoneStringWithNone(message):
|
||||
|
@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message):
|
|||
|
||||
|
||||
async def _handle_generate(websocket, message):
|
||||
global args
|
||||
|
||||
# update args parameters which control the model settings
|
||||
if message.get('autoregressive_model'):
|
||||
args.autoregressive_model = message['autoregressive_model']
|
||||
update_autoregressive_model(message['autoregressive_model'])
|
||||
|
||||
if message.get('diffusion_model'):
|
||||
args.diffusion_model = message['diffusion_model']
|
||||
update_diffusion_model(message['diffusion_model'])
|
||||
|
||||
if message.get('tokenizer_json'):
|
||||
args.tokenizer_json = message['tokenizer_json']
|
||||
update_tokenizer(message['tokenizer_json'])
|
||||
|
||||
if message.get('sample_batch_size'):
|
||||
global args
|
||||
args.sample_batch_size = message['sample_batch_size']
|
||||
|
||||
message['result'] = generate(**message)
|
||||
|
@ -71,7 +70,7 @@ async def _handle_connection(websocket, path):
|
|||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
print("websocket: server started")
|
||||
print(f"websocket: server started on ws://{host}:{port}")
|
||||
|
||||
async with serve(_handle_connection, host, port, ping_interval=None):
|
||||
await asyncio.Future() # run forever
|
||||
|
|
|
@ -26,7 +26,9 @@ if __name__ == "__main__":
|
|||
if not args.defer_tts_load:
|
||||
tts = load_tts()
|
||||
|
||||
start_websocket_server('127.0.0.1', 8069)
|
||||
if args.websocket_enabled:
|
||||
start_websocket_server(args.websocket_listen_address, args.websocket_listen_port)
|
||||
|
||||
webui.block_thread()
|
||||
elif __name__ == "main":
|
||||
from fastapi import FastAPI
|
||||
|
|
|
@ -3303,6 +3303,10 @@ def setup_args(cli=False):
|
|||
|
||||
'training-default-halfp': False,
|
||||
'training-default-bnb': True,
|
||||
|
||||
'websocket-listen-address': "127.0.0.1",
|
||||
'websocket-listen-port': 8069,
|
||||
'websocket-enabled': False
|
||||
}
|
||||
|
||||
if os.path.isfile('./config/exec.json'):
|
||||
|
@ -3356,6 +3360,10 @@ def setup_args(cli=False):
|
|||
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||
|
||||
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
|
||||
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
|
||||
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
|
||||
|
||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||
if cli:
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user