implicitly load either normal pickled weights or safetensors on loading the model

mrq 2024-08-03 23:34:18 -05:00
parent c09133d00f
commit 2cb465018b
4 changed files with 24 additions and 6 deletions
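In short: every place that builds a weights path from cfg.weights_format now goes through a new pick_path() helper, which probes each suffix in cfg.supported_weights_formats and returns the first file that actually exists on disk. Together with a suffix-aware loader, a checkpoint exported as safetensors gets picked up even when the config still says "pth", and vice versa. A minimal sketch of the loading side of that idea (this torch_load is an illustrative stand-in, not the repo's actual implementation, which is not shown in this diff):

from pathlib import Path

import torch
from safetensors.torch import load_file as sft_load

def torch_load(path):
    # illustrative stand-in: dispatch on the file suffix so safetensors
    # (".sft"/".safetensors") and pickled (".pt"/".pth") weights both load
    path = Path(path)
    if path.suffix in (".sft", ".safetensors"):
        return sft_load(path)
    return torch.load(path, map_location="cpu")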

View File

@@ -718,6 +718,8 @@ class Config(BaseConfig):
     weights_format: str = "pth" # "pth" | "sft"
+    supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
+
     @property
     def model(self):
         for i, model in enumerate(self.models):

View File

@@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..utils.io import torch_save, torch_load
+from ..utils.io import torch_save, torch_load, pick_path
 from ..models.lora import apply_lora, lora_load_state_dict

 import torch
@@ -43,7 +43,7 @@ def load_engines(training=True):
         checkpoint_path = cfg.ckpt_dir / name / "latest"
         # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-        load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
+        load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

         # actually use the lora-specific checkpoint if available
         if cfg.lora is not None:
@@ -52,7 +52,7 @@ def load_engines(training=True):
         # to handle the issue of training with deepspeed, but inferencing with local
         if checkpoint_path.exists() and backend == "local":
             tag = open(checkpoint_path).read()
-            checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
+            checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
             print("Checkpoint missing, but weights found:", load_path)
@@ -197,7 +197,7 @@ def load_engines(training=True):
         # load lora weights if exists
         if cfg.lora is not None:
-            lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
+            lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

             if lora_path.exists():
                 print( "Loaded LoRA state dict:", lora_path )

View File

@@ -8,6 +8,16 @@ from safetensors.torch import save_file as sft_save
 def coerce_path( path ):
     return path if isinstance( path, Path ) else Path(path)

+def pick_path( path, *suffixes ):
+    suffixes = [*suffixes]
+
+    for suffix in suffixes:
+        p = path.with_suffix( suffix )
+        if p.exists():
+            return p
+
+    return path
+
 def is_dict_of( d, t ):
     if not isinstance( d, dict ):
         return False
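A quick, self-contained sanity check of pick_path()'s behavior (the temporary directory and file names here are invented for illustration):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # only the safetensors variant exists on disk
    sft = Path(tmp) / "fp32.sft"
    sft.touch()

    requested = Path(tmp) / "fp32.pth"
    # the requested .pth name is transparently redirected to the .sft file
    assert pick_path( requested, ".sft", ".safetensors", ".pt", ".pth" ) == sft

    # when no candidate suffix matches, the original path comes back unchanged
    missing = Path(tmp) / "other.pth"
    assert pick_path( missing, ".sft", ".safetensors" ) == missing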

View File

@@ -33,7 +33,10 @@ def gradio_wrapper(inputs):
         def wrapped_function(*args, **kwargs):
             for i, key in enumerate(inputs):
                 kwargs[key] = args[i]
-            return fun(**kwargs)
+            try:
+                return fun(**kwargs)
+            except Exception as e:
+                raise gr.Error(str(e))
         return wrapped_function
     return decorated
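Re-raising inside the wrapper as gr.Error means any backend exception surfaces as an error popup in the browser instead of only a console traceback. A minimal standalone illustration of the same pattern (the handler and interface here are invented for this sketch):

import gradio as gr

def handler(text):
    if not text:
        # gr.Error is rendered by Gradio as an error popup in the UI
        raise gr.Error("No input provided.")
    return text.upper()

demo = gr.Interface(fn=handler, inputs="text", outputs="text")
# demo.launch()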
@@ -95,6 +98,9 @@ def init_tts(yaml=None, restart=False):
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
 def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
+    if not cfg.yaml_path:
+        raise Exception("No YAML loaded.")
+
     if kwargs.pop("dynamic-sampling", False):
         kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
         kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
@@ -131,7 +137,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     """
     if not args.references:
-        raise ValueError("No reference audio provided.")
+        raise Exception("No reference audio provided.")
     """

     tts = init_tts()