@ -4,8 +4,7 @@ from threading import Thread
from websockets . server import serve
from utils import generate , get_autoregressive_models , get_voice_list
from utils import generate , get_autoregressive_models , get_voice_list , args
# this is a not so nice workaround to set values to None if their string value is "None"
def replaceNoneStringWithNone ( message ) :
@ -19,6 +18,21 @@ def replaceNoneStringWithNone(message):
async def _handle_generate ( websocket , message ) :
global args
# update args parameters which control the model settings
if message . get ( ' autoregressive_model ' ) :
args . autoregressive_model = message [ ' autoregressive_model ' ]
if message . get ( ' diffusion_model ' ) :
args . diffusion_model = message [ ' diffusion_model ' ]
if message . get ( ' tokenizer_json ' ) :
args . tokenizer_json = message [ ' tokenizer_json ' ]
if message . get ( ' sample_batch_size ' ) :
args . sample_batch_size = message [ ' sample_batch_size ' ]
message [ ' result ' ] = generate ( * * message )
await websocket . send ( json . dumps ( replaceNoneStringWithNone ( message ) ) )