From c83670c38c4dde675e6fb62eaffdc97f7e857bbb Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 3 Nov 2024 19:19:15 -0600
Subject: [PATCH] Windows specific fixes (to-do: find libespeak-ng.dll
 automatically because it cannot be trusted to do it by default)

---
 setup.py                    |  2 +-
 vall_e/config.py            |  8 ++++++++
 vall_e/demo.py              |  6 ++++++
 vall_e/engines/base.py      |  8 ++++++--
 vall_e/models/arch/llama.py |  2 +-
 vall_e/webui.py             | 11 ++++++++++-
 6 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index ba14d79..f4487fe 100755
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ setup(
 	packages=find_packages(),
 	install_requires=(
 		# training backends
-		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"])
 		+ [
 		# logging niceties
 		"coloredlogs>=15.0.1",
diff --git a/vall_e/config.py b/vall_e/config.py
index 401c2a5..fb32f21 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -948,6 +948,14 @@ class Config(BaseConfig):
 				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
 				del model["interleave"]
 
+			if "p_rvq_levels" in model["experimental"]:
+				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
+				del model["experimental"]["p_rvq_levels"]
+
+			if "p_len_train" in model["experimental"]:
+				model["experimental"]["len_train_p"] = model["experimental"]["p_len_train"]
+				del model["experimental"]["p_len_train"]
+
 		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
 		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
 
diff --git a/vall_e/demo.py b/vall_e/demo.py
index ef74687..ccb5d8c 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -156,6 +156,12 @@ def main():
 
 		comparison_kwargs["disabled"]["layer_skip"] = False
 		comparison_kwargs["enabled"]["layer_skip"] = True
+	elif args.comparison == "refine-on-stop":
+		comparison_kwargs["suffix"] = "refine-on-stop"
+		comparison_kwargs["titles"] = [f"Without Ro", "With Ro"]
+
+		comparison_kwargs["disabled"]["refine_on_stop"] = False
+		comparison_kwargs["enabled"]["refine_on_stop"] = True
 	elif args.comparison == "ar-temp":
 		current_temp = args.ar_temp
 		other_temp = 1.0
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 03f5233..ce065c2 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -48,8 +48,12 @@ from ..utils import wrapper as ml
 
 _logger = logging.getLogger(__name__)
 
-if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
-	init_distributed(torch.distributed.init_process_group)
+# windows throws an error here
+try:
+	if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
+		init_distributed(torch.distributed.init_process_group)
+except Exception as e:
+	pass
 
 # A very naive engine implementation using barebones PyTorch
 class Engine():
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index c5ad089..25755a1 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -5,7 +5,7 @@ import torch
 import logging
 import random
 
-from typing import Literal, overload, Optional, Tuple, Union, List, Unpack
+from typing import Literal, overload, Optional, Tuple, Union, List
 
 from torch import Tensor, nn
 from transformers.cache_utils import Cache
diff --git a/vall_e/webui.py b/vall_e/webui.py
index d76f2a6..b81c706 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import re
 import math
 import argparse
@@ -22,6 +23,8 @@ from .emb.qnt import decode_to_wave
 
 from .data import get_lang_symmap, get_random_prompt
 
+is_windows = sys.platform.startswith("win")
+
 tts = None
 
 layout = {}
@@ -68,6 +71,9 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
 			continue
 		configs.append( sft )
 
+	if is_windows:
+		configs = [ str(p) for p in configs ]
+
 	return configs
 
 def get_dtypes():
@@ -199,7 +205,10 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--refine-on-stop", action="store_true")
 	args, unknown = parser.parse_known_args()
 
-	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
+	if is_windows:
+		tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
+	else:
+		tmp = tempfile.NamedTemporaryFile(suffix='.wav')
 
 	"""
 	if not args.references:
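
Note (outside the patch itself, on the to-do in the subject line): one possible way to find libespeak-ng.dll automatically on Windows is to probe the default eSpeak NG install directories and hand the result to phonemizer via its PHONEMIZER_ESPEAK_LIBRARY override before the G2P backend is imported. This is only a sketch under assumptions of mine (the stock installer paths and that the project's phonemizer backend honors that variable), not part of the commit.

import os
import sys
from pathlib import Path

def find_espeak_dll():
	"""Return a best-guess path to libespeak-ng.dll on Windows, or None if not found."""
	if not sys.platform.startswith("win"):
		return None
	# default install locations used by the eSpeak NG Windows installer (assumed)
	candidates = [
		Path(os.environ.get("PROGRAMFILES", r"C:\Program Files")) / "eSpeak NG" / "libespeak-ng.dll",
		Path(os.environ.get("PROGRAMFILES(X86)", r"C:\Program Files (x86)")) / "eSpeak NG" / "libespeak-ng.dll",
	]
	for path in candidates:
		if path.exists():
			return str(path)
	return None

dll = find_espeak_dll()
if dll and "PHONEMIZER_ESPEAK_LIBRARY" not in os.environ:
	# phonemizer consults this variable before doing its own library lookup
	os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = dll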