websocket server: fix for model loading (just overriding args didn't do it after all...)

This commit is contained in:
ben_mkiv 2023-08-26 01:40:35 +02:00
parent 00b173857d
commit 578a5bcadd

View File

@ -4,7 +4,7 @@ from threading import Thread
from websockets.server import serve
from utils import generate, get_autoregressive_models, get_voice_list, args
from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer
# this is a not so nice workaround to set values to None if their string value is "None"
def replaceNoneStringWithNone(message):
@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message):
async def _handle_generate(websocket, message):
global args
# update args parameters which control the model settings
if message.get('autoregressive_model'):
args.autoregressive_model = message['autoregressive_model']
update_autoregressive_model(message['autoregressive_model'])
if message.get('diffusion_model'):
args.diffusion_model = message['diffusion_model']
update_diffusion_model(message['diffusion_model'])
if message.get('tokenizer_json'):
args.tokenizer_json = message['tokenizer_json']
update_tokenizer(message['tokenizer_json'])
if message.get('sample_batch_size'):
global args
args.sample_batch_size = message['sample_batch_size']
message['result'] = generate(**message)