import asyncio
import json
from threading import Thread

from websockets.server import serve

from utils import generate, get_autoregressive_models, get_voice_list, args

# Workaround: values arriving as the literal string "None" are converted to a real None.
def replaceNoneStringWithNone(message):
    """Replace every 'None' string value in *message* with ``None``, in place.

    Fields listed in ``literal_none_allowed`` are left untouched because the
    string "None" is a legitimate value for them.

    Returns the same ``message`` object that was passed in.
    """
    literal_none_allowed = ['text']  # fields which CAN hold "None" as a literal string

    for key in message:
        if not (message[key] == 'None'):
            continue
        if key in literal_none_allowed:
            continue
        message[key] = None

    return message


async def _handle_generate(websocket, message):
    """Run ``generate`` for a client request and send the outcome back.

    Model-setting overrides present in *message* are copied onto the shared
    ``args`` object first. ``generate`` is then called with every message field
    as a keyword argument, its return value is stored under ``'result'``, and
    the whole message is sent back over *websocket* as JSON.
    """
    # args attributes that control the model settings; an override is applied
    # only when the message carries a truthy value for it
    model_settings = (
        'autoregressive_model',
        'diffusion_model',
        'tokenizer_json',
        'sample_batch_size',
    )
    for setting in model_settings:
        override = message.get(setting)
        if override:
            setattr(args, setting, override)

    message['result'] = generate(**message)
    reply = json.dumps(replaceNoneStringWithNone(message))
    await websocket.send(reply)


async def _handle_get_autoregressive_models(websocket, message):
    """Attach the result of ``get_autoregressive_models()`` to *message* and send it back as JSON."""
    models = get_autoregressive_models()
    message['result'] = models
    reply = json.dumps(replaceNoneStringWithNone(message))
    await websocket.send(reply)


async def _handle_get_voice_list(websocket, message):
    """Attach the result of ``get_voice_list()`` to *message* and send it back as JSON."""
    voices = get_voice_list()
    message['result'] = voices
    reply = json.dumps(replaceNoneStringWithNone(message))
    await websocket.send(reply)


async def _handle_message(websocket, message):
    """Dispatch a decoded client message to the handler matching its 'action'.

    Args:
        websocket: connection the reply is sent on (passed through to the handler).
        message: the JSON-decoded payload. Only JSON objects (dicts) can carry
            an 'action'; anything else is logged as unhandled.

    Recognised actions are 'generate', 'get_voices' and
    'get_autoregressive_models'. Messages without a recognised action are
    logged and otherwise ignored.
    """
    if not isinstance(message, dict):
        # Valid JSON that is not an object (list, number, string, ...) has no
        # fields to normalise or dispatch on; log it instead of failing.
        print(f"websocket: unhandled message: {message}")
        return

    message = replaceNoneStringWithNone(message)
    action = message.get('action')

    if action == 'generate':
        await _handle_generate(websocket, message)
    elif action == 'get_voices':
        await _handle_get_voice_list(websocket, message)
    elif action == 'get_autoregressive_models':
        await _handle_get_autoregressive_models(websocket, message)
    else:
        # message is a dict, so it must be formatted rather than concatenated
        # to a str (str + dict raises TypeError).
        print(f"websocket: unhandled message: {message}")


async def _handle_connection(websocket, path=None):
    """Serve one client connection: decode each incoming frame and dispatch it.

    Args:
        websocket: the connection to read messages from and reply on.
        path: unused; kept (now optional) so the handler works whether it is
            called with ``(websocket, path)`` or with the connection alone.

    Frames that are not valid JSON are logged and skipped. A ``ValueError``
    raised while handling a decoded message is logged separately, and the
    connection keeps being served in both cases.
    """
    print("websocket: client connected")

    async for raw_message in websocket:
        # Keep the try body limited to decoding so that only real parse
        # failures are reported as malformed JSON.
        try:
            payload = json.loads(raw_message)
        except ValueError:
            print("websocket: malformed json received")
            continue

        try:
            await _handle_message(websocket, payload)
        except ValueError as err:
            # Not a JSON problem: report it as a handling error and keep the
            # connection open for further messages.
            print(f"websocket: error while handling message: {err}")


async def _run(host: str, port: int):
    """Serve websocket connections on *host*:*port* until the task is cancelled."""
    print("websocket: server started")

    server = serve(_handle_connection, host, port, ping_interval=None)
    async with server:
        # a future that is never resolved keeps the server context open forever
        never_done = asyncio.Future()
        await never_done


def _run_server(listen_address: str, port: int):
    """Run the websocket server coroutine to completion via ``asyncio.run``."""
    server_coro = _run(host=listen_address, port=port)
    asyncio.run(server_coro)


def start_websocket_server(listen_address: str, port: int):
    """Start the websocket server in a daemon thread and return immediately."""
    server_thread = Thread(
        target=_run_server,
        args=[listen_address, port],
        daemon=True,
    )
    server_thread.start()