From a17078a792f79711a082528f0b09be2604b5c1f1 Mon Sep 17 00:00:00 2001 From: ben_mkiv Date: Wed, 23 Aug 2023 19:20:54 +0200 Subject: [PATCH] Revert "added parameter to specify the autoregressive_model (tho it still loads the default model first, and then loads the target model, which seems to be because TTS loading just loads whatever is set in the settings first)" This reverts commit d1dbe3e464160d9dc16293850c08b265a43605fa. --- src/api/websocket_server.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/api/websocket_server.py b/src/api/websocket_server.py index 7695fbc..87fb2d9 100644 --- a/src/api/websocket_server.py +++ b/src/api/websocket_server.py @@ -4,7 +4,7 @@ from threading import Thread from websockets.server import serve -from utils import generate, get_autoregressive_models, get_voice_list +from utils import generate, get_autoregressive_models, get_voice_list, tts, args # this is a not so nice workaround to set values to None if their string value is "None" @@ -19,6 +19,21 @@ def replaceNoneStringWithNone(message): async def _handle_generate(websocket, message): + global args + + # update args parameters which control the model settings + if message.get('autoregressive_model'): + args.autoregressive_model = message['autoregressive_model'] + + if message.get('diffusion_model'): + args.diffusion_model = message['diffusion_model'] + + if message.get('tokenizer_json'): + args.tokenizer_json = message['tokenizer_json'] + + if message.get('sample_batch_size'): + args.sample_batch_size = message['sample_batch_size'] + message['result'] = generate(**message) await websocket.send(json.dumps(replaceNoneStringWithNone(message)))