implicitly load either normal pickled weights or safetensors on loading the model
This commit is contained in:
parent
c09133d00f
commit
2cb465018b
|
@ -718,6 +718,8 @@ class Config(BaseConfig):
|
|||
|
||||
weights_format: str = "pth" # "pth" | "sft"
|
||||
|
||||
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
for i, model in enumerate(self.models):
|
||||
|
|
|
@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
|||
|
||||
from ..models import get_models, get_model
|
||||
from ..utils import wrapper as ml
|
||||
from ..utils.io import torch_save, torch_load
|
||||
from ..utils.io import torch_save, torch_load, pick_path
|
||||
from ..models.lora import apply_lora, lora_load_state_dict
|
||||
|
||||
import torch
|
||||
|
@ -43,7 +43,7 @@ def load_engines(training=True):
|
|||
|
||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
|
||||
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
# actually use the lora-specific checkpoint if available
|
||||
if cfg.lora is not None:
|
||||
|
@ -52,7 +52,7 @@ def load_engines(training=True):
|
|||
# to handle the issue of training with deepspeed, but inferencing with local
|
||||
if checkpoint_path.exists() and backend == "local":
|
||||
tag = open(checkpoint_path).read()
|
||||
checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
|
||||
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
print("Checkpoint missing, but weights found:", load_path)
|
||||
|
@ -197,7 +197,7 @@ def load_engines(training=True):
|
|||
|
||||
# load lora weights if exists
|
||||
if cfg.lora is not None:
|
||||
lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
|
||||
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
if lora_path.exists():
|
||||
print( "Loaded LoRA state dict:", lora_path )
|
||||
|
||||
|
|
|
@ -8,6 +8,16 @@ from safetensors.torch import save_file as sft_save
|
|||
def coerce_path( path ):
|
||||
return path if isinstance( path, Path ) else Path(path)
|
||||
|
||||
def pick_path( path, *suffixes ):
|
||||
suffixes = [*suffixes]
|
||||
|
||||
for suffix in suffixes:
|
||||
p = path.with_suffix( suffix )
|
||||
if p.exists():
|
||||
return p
|
||||
|
||||
return path
|
||||
|
||||
def is_dict_of( d, t ):
|
||||
if not isinstance( d, dict ):
|
||||
return False
|
||||
|
|
|
@ -33,7 +33,10 @@ def gradio_wrapper(inputs):
|
|||
def wrapped_function(*args, **kwargs):
|
||||
for i, key in enumerate(inputs):
|
||||
kwargs[key] = args[i]
|
||||
return fun(**kwargs)
|
||||
try:
|
||||
return fun(**kwargs)
|
||||
except Exception as e:
|
||||
raise gr.Error(str(e))
|
||||
return wrapped_function
|
||||
return decorated
|
||||
|
||||
|
@ -95,6 +98,9 @@ def init_tts(yaml=None, restart=False):
|
|||
|
||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
||||
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||
if not cfg.yaml_path:
|
||||
raise Exception("No YAML loaded.")
|
||||
|
||||
if kwargs.pop("dynamic-sampling", False):
|
||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||
|
@ -131,7 +137,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
|
||||
"""
|
||||
if not args.references:
|
||||
raise ValueError("No reference audio provided.")
|
||||
raise Exception("No reference audio provided.")
|
||||
"""
|
||||
|
||||
tts = init_tts()
|
||||
|
|
Loading…
Reference in New Issue
Block a user