forked from mrq/ai-voice-cloning
websocket server: fix for model loading (just overriding args didn't do it after all...)
This commit is contained in:
parent
00b173857d
commit
578a5bcadd
|
@ -4,7 +4,7 @@ from threading import Thread
|
||||||
|
|
||||||
from websockets.server import serve
|
from websockets.server import serve
|
||||||
|
|
||||||
from utils import generate, get_autoregressive_models, get_voice_list, args
|
from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer
|
||||||
|
|
||||||
# this is a not so nice workaround to set values to None if their string value is "None"
|
# this is a not so nice workaround to set values to None if their string value is "None"
|
||||||
def replaceNoneStringWithNone(message):
|
def replaceNoneStringWithNone(message):
|
||||||
|
@ -18,19 +18,18 @@ def replaceNoneStringWithNone(message):
|
||||||
|
|
||||||
|
|
||||||
async def _handle_generate(websocket, message):
|
async def _handle_generate(websocket, message):
|
||||||
global args
|
|
||||||
|
|
||||||
# update args parameters which control the model settings
|
# update args parameters which control the model settings
|
||||||
if message.get('autoregressive_model'):
|
if message.get('autoregressive_model'):
|
||||||
args.autoregressive_model = message['autoregressive_model']
|
update_autoregressive_model(message['autoregressive_model'])
|
||||||
|
|
||||||
if message.get('diffusion_model'):
|
if message.get('diffusion_model'):
|
||||||
args.diffusion_model = message['diffusion_model']
|
update_diffusion_model(message['diffusion_model'])
|
||||||
|
|
||||||
if message.get('tokenizer_json'):
|
if message.get('tokenizer_json'):
|
||||||
args.tokenizer_json = message['tokenizer_json']
|
update_tokenizer(message['tokenizer_json'])
|
||||||
|
|
||||||
if message.get('sample_batch_size'):
|
if message.get('sample_batch_size'):
|
||||||
|
global args
|
||||||
args.sample_batch_size = message['sample_batch_size']
|
args.sample_batch_size = message['sample_batch_size']
|
||||||
|
|
||||||
message['result'] = generate(**message)
|
message['result'] = generate(**message)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user