From 2cb465018b127fda00b80e1d1d60141b69e5d99b Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 3 Aug 2024 23:34:18 -0500 Subject: [PATCH] implicitly load either normal pickled weights or safetensors on loading the model --- vall_e/config.py | 2 ++ vall_e/engines/__init__.py | 8 ++++---- vall_e/utils/io.py | 10 ++++++++++ vall_e/webui.py | 10 ++++++++-- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index ea3f0cf..687da14 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -718,6 +718,8 @@ class Config(BaseConfig): weights_format: str = "pth" # "pth" | "sft" + supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"]) + @property def model(self): for i, model in enumerate(self.models): diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index f0c76d6..a3d2b32 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine from ..models import get_models, get_model from ..utils import wrapper as ml -from ..utils.io import torch_save, torch_load +from ..utils.io import torch_save, torch_load, pick_path from ..models.lora import apply_lora, lora_load_state_dict import torch @@ -43,7 +43,7 @@ def load_engines(training=True): checkpoint_path = cfg.ckpt_dir / name / "latest" # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present - load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}" + load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) # actually use the lora-specific checkpoint if available if cfg.lora is not None: @@ -52,7 +52,7 @@ def load_engines(training=True): # to handle the issue of training with deepspeed, but inferencing with local if checkpoint_path.exists() and backend == "local": tag = open(checkpoint_path).read() - checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}" + checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): print("Checkpoint missing, but weights found:", load_path) @@ -197,7 +197,7 @@ def load_engines(training=True): # load lora weights if exists if cfg.lora is not None: - lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}" + lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) if lora_path.exists(): print( "Loaded LoRA state dict:", lora_path ) diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index 4dad4bf..afc2033 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -8,6 +8,16 @@ from safetensors.torch import save_file as sft_save def coerce_path( path ): return path if isinstance( path, Path ) else Path(path) +def pick_path( path, *suffixes ): + suffixes = [*suffixes] + + for suffix in suffixes: + p = path.with_suffix( suffix ) + if p.exists(): + return p + + return path + def is_dict_of( d, t ): if not isinstance( d, dict ): return False diff --git a/vall_e/webui.py b/vall_e/webui.py index 5b63fd4..8ca2919 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -33,7 +33,10 @@ def gradio_wrapper(inputs): def wrapped_function(*args, **kwargs): for i, key in enumerate(inputs): kwargs[key] = args[i] - return fun(**kwargs) + try: + return fun(**kwargs) + except Exception as e: + raise gr.Error(str(e)) return wrapped_function return decorated @@ -95,6 +98,9 @@ def init_tts(yaml=None, restart=False): @gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): + if not cfg.yaml_path: + raise Exception("No YAML loaded.") + if kwargs.pop("dynamic-sampling", False): kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR @@ -131,7 +137,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): """ if not args.references: - raise ValueError("No reference audio provided.") + raise Exception("No reference audio provided.") """ tts = init_tts()