Websocket server, override args parameters for model settings (squashed)
Revert "favor existing arguments from parameters (kwargs) over global (args)"
This reverts commit 89102347a9
.
args are now updated in the websocket server
This commit is contained in:
parent
5d73d9e71c
commit
ce24ba41e2
|
@ -4,8 +4,7 @@ from threading import Thread
|
|||
|
||||
from websockets.server import serve
|
||||
|
||||
from utils import generate, get_autoregressive_models, get_voice_list
|
||||
|
||||
from utils import generate, get_autoregressive_models, get_voice_list, args
|
||||
|
||||
# this is a not so nice workaround to set values to None if their string value is "None"
|
||||
def replaceNoneStringWithNone(message):
|
||||
|
@ -19,6 +18,21 @@ def replaceNoneStringWithNone(message):
|
|||
|
||||
|
||||
async def _handle_generate(websocket, message):
|
||||
global args
|
||||
|
||||
# update args parameters which control the model settings
|
||||
if message.get('autoregressive_model'):
|
||||
args.autoregressive_model = message['autoregressive_model']
|
||||
|
||||
if message.get('diffusion_model'):
|
||||
args.diffusion_model = message['diffusion_model']
|
||||
|
||||
if message.get('tokenizer_json'):
|
||||
args.tokenizer_json = message['tokenizer_json']
|
||||
|
||||
if message.get('sample_batch_size'):
|
||||
args.sample_batch_size = message['sample_batch_size']
|
||||
|
||||
message['result'] = generate(**message)
|
||||
await websocket.send(json.dumps(replaceNoneStringWithNone(message)))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user