Implicitly load either normal pickled weights or safetensors when loading the model
This commit is contained in:
parent
c09133d00f
commit
2cb465018b
@ -718,6 +718,8 @@ class Config(BaseConfig):
|
|||||||
|
|
||||||
weights_format: str = "pth" # "pth" | "sft"
|
weights_format: str = "pth" # "pth" | "sft"
|
||||||
|
|
||||||
|
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
for i, model in enumerate(self.models):
|
for i, model in enumerate(self.models):
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
|||||||
|
|
||||||
from ..models import get_models, get_model
|
from ..models import get_models, get_model
|
||||||
from ..utils import wrapper as ml
|
from ..utils import wrapper as ml
|
||||||
from ..utils.io import torch_save, torch_load
|
from ..utils.io import torch_save, torch_load, pick_path
|
||||||
from ..models.lora import apply_lora, lora_load_state_dict
|
from ..models.lora import apply_lora, lora_load_state_dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -43,7 +43,7 @@ def load_engines(training=True):
|
|||||||
|
|
||||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||||
load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
|
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||||
|
|
||||||
# actually use the lora-specific checkpoint if available
|
# actually use the lora-specific checkpoint if available
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
@ -52,7 +52,7 @@ def load_engines(training=True):
|
|||||||
# to handle the issue of training with deepspeed, but inferencing with local
|
# to handle the issue of training with deepspeed, but inferencing with local
|
||||||
if checkpoint_path.exists() and backend == "local":
|
if checkpoint_path.exists() and backend == "local":
|
||||||
tag = open(checkpoint_path).read()
|
tag = open(checkpoint_path).read()
|
||||||
checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
|
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||||
|
|
||||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||||
print("Checkpoint missing, but weights found:", load_path)
|
print("Checkpoint missing, but weights found:", load_path)
|
||||||
@ -197,7 +197,7 @@ def load_engines(training=True):
|
|||||||
|
|
||||||
# load lora weights if exists
|
# load lora weights if exists
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
|
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||||
if lora_path.exists():
|
if lora_path.exists():
|
||||||
print( "Loaded LoRA state dict:", lora_path )
|
print( "Loaded LoRA state dict:", lora_path )
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,16 @@ from safetensors.torch import save_file as sft_save
|
|||||||
def coerce_path( path ):
	"""Normalize ``path`` into a ``Path``, passing through values that already are one."""
	if isinstance( path, Path ):
		return path
	return Path( path )
|
||||||
|
|
||||||
|
def pick_path( path, *suffixes ):
	"""Return the first existing variant of ``path`` among the given suffixes.

	Each ``suffix`` (e.g. ``".sft"``, ``".pth"``) replaces ``path``'s
	extension in turn, and the first candidate that exists on disk is
	returned. If no candidate exists, ``path`` is returned unchanged so
	the caller can still ``.exists()``-check it (the call sites in this
	file do exactly that before loading weights).

	Args:
		path: base path (``Path`` or ``str``); its extension is swapped out.
		*suffixes: candidate extensions, tried in order, each including
			the leading dot.

	Returns:
		Path: the first existing candidate, or the (coerced) original path.
	"""
	# accept plain strings too, matching coerce_path's convention in this module
	path = path if isinstance( path, Path ) else Path( path )
	for suffix in suffixes:
		candidate = path.with_suffix( suffix )
		if candidate.exists():
			return candidate
	# nothing found on disk; hand back the original so callers decide what to do
	return path
|
||||||
|
|
||||||
def is_dict_of( d, t ):
|
def is_dict_of( d, t ):
|
||||||
if not isinstance( d, dict ):
|
if not isinstance( d, dict ):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -33,7 +33,10 @@ def gradio_wrapper(inputs):
|
|||||||
def wrapped_function(*args, **kwargs):
|
def wrapped_function(*args, **kwargs):
|
||||||
for i, key in enumerate(inputs):
|
for i, key in enumerate(inputs):
|
||||||
kwargs[key] = args[i]
|
kwargs[key] = args[i]
|
||||||
return fun(**kwargs)
|
try:
|
||||||
|
return fun(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
raise gr.Error(str(e))
|
||||||
return wrapped_function
|
return wrapped_function
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
@ -95,6 +98,9 @@ def init_tts(yaml=None, restart=False):
|
|||||||
|
|
||||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
||||||
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
|
if not cfg.yaml_path:
|
||||||
|
raise Exception("No YAML loaded.")
|
||||||
|
|
||||||
if kwargs.pop("dynamic-sampling", False):
|
if kwargs.pop("dynamic-sampling", False):
|
||||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||||
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||||
@ -131,7 +137,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if not args.references:
|
if not args.references:
|
||||||
raise ValueError("No reference audio provided.")
|
raise Exception("No reference audio provided.")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tts = init_tts()
|
tts = init_tts()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user