added option to load from a model state dict directly instead of a yaml (to-do: do this for LoRAs too), automatically download the default model if none is provided

This commit is contained in:
mrq 2024-10-25 22:15:15 -05:00
parent a96f5aee32
commit ccf71dc1b6
5 changed files with 128 additions and 63 deletions

View File

@ -21,6 +21,7 @@ from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
from .utils.io import torch_load
from .utils import set_seed, prune_missing
@dataclass()
@ -30,7 +31,13 @@ class BaseConfig:
@property
def cfg_path(self):
return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data"
if self.yaml_path:
return Path(self.yaml_path.parent)
if self.model_path:
return Path(self.model_path.parent)
return Path(__file__).parent.parent / "data"
@property
def rel_path(self):
@ -93,8 +100,6 @@ class BaseConfig:
def prune_missing( cls, yaml ):
default = cls(**{})
default.format()
#default = json.loads(default.dumps())
yaml, missing = prune_missing( source=default, dest=yaml )
if missing:
_logger.warning(f'Missing keys in YAML: {missing}')
@ -108,6 +113,17 @@ class BaseConfig:
state = cls.prune_missing( state )
return cls(**state)
@classmethod
def from_model( cls, model_path ):
if not model_path.exists():
raise Exception(f'Model path does not exist: {model_path}')
# load state dict and copy its stored model config
state_dict = torch_load( model_path )["config"]
state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path }
return cls(**state)
@classmethod
def from_cli(cls, args=sys.argv):
# legacy support for yaml=`` format
@ -117,8 +133,12 @@ class BaseConfig:
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if args.model:
return cls.from_model( args.model )
if args.yaml:
return cls.from_yaml( args.yaml )
@ -807,11 +827,15 @@ class Config(BaseConfig):
return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x
# I don't remember why this is needed
# this gets called from vall_e.inference
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
def load_model( self, config_path ):
tmp = Config.from_model( config_path )
self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
@ -870,7 +894,27 @@ class Config(BaseConfig):
if isinstance(self.optimizations, type):
self.optimizations = dict()
self.dataset = Dataset(**self.dataset)
if isinstance( self.dataset, dict ):
self.dataset = Dataset(**self.dataset)
if isinstance( self.hyperparameters, dict ):
self.hyperparameters = Hyperparameters(**self.hyperparameters)
if isinstance( self.evaluation, dict ):
self.evaluation = Evaluation(**self.evaluation)
if isinstance( self.trainer, dict ):
self.trainer = Trainer(**self.trainer)
if isinstance( self.trainer.deepspeed, dict ):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
if isinstance( self.inference, dict ):
self.inference = Inference(**self.inference)
if isinstance( self.optimizations, dict ):
self.optimizations = Optimizations(**self.optimizations)
# convert to expanded paths
self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ]
self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ]
@ -906,28 +950,15 @@ class Config(BaseConfig):
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
if not self.models:
self.models = [ Model() ]
for model in self.models:
if not isinstance( model.experimental, dict ):
continue
model.experimental = ModelExperimentalSettings(**model.experimental)
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
if not isinstance(self.trainer.deepspeed, type):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.inference = Inference(**self.inference)
self.optimizations = Optimizations(**self.optimizations)
if isinstance( model.experimental, dict ):
model.experimental = ModelExperimentalSettings(**model.experimental)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
@ -961,7 +992,7 @@ class Config(BaseConfig):
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
tokenizer_path = self.rel_path / self.tokenizer_path
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / self.tokenizer_path

View File

@ -57,6 +57,10 @@ def load_engines(training=True, **model_kwargs):
tag = open(checkpoint_path).read()
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
# if loaded using --model=
if cfg.model_path and cfg.model_path.exists():
load_path = cfg.model_path
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
loads_state_dict = True

View File

@ -19,6 +19,7 @@ from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
from .models import download_model, DEFAULT_MODEL_PATH
if deepspeed_available:
import deepspeed
@ -34,9 +35,18 @@ class TTS():
self.loading = False
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config:
if not config:
download_model()
config = DEFAULT_MODEL_PATH
if config.suffix == ".yaml":
_logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config )
elif config.suffix == ".sft":
_logger.info(f"Loading model: {config}")
cfg.load_model( config )
else:
raise Exception(f"Unknown config passed: {config}")
try:
cfg.format( training=False )

View File

@ -4,31 +4,20 @@ import requests
from tqdm import tqdm
from pathlib import Path
import time
_logger = logging.getLogger(__name__)
# to-do: implement automatically downloading model
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-llama-8.sft"
DEFAULT_MODEL_URLS = {
'ar+nar-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
'ar+nar-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
}
# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
scale = 1
if unit == "KiB":
scale = (1024)
elif unit == "MiB":
scale = (1024 * 1024)
elif unit == "MiB":
scale = (1024 * 1024 * 1024)
elif unit == "KB":
scale = (1000)
elif unit == "MB":
scale = (1000 * 1000)
elif unit == "MB":
scale = (1000 * 1000 * 1000)
def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
name = save_path.name
url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
if url is None:
@ -37,19 +26,32 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
if not save_path.parent.exists():
save_path.parent.mkdir(parents=True, exist_ok=True)
r = requests.get(url, stream=True)
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
headers = {}
# check if modified
if save_path.exists():
headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))}
r = requests.get(url, headers=headers, stream=True)
# not modified
if r.status_code == 304:
r.close()
return
# to-do: validate lengths match
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
with open(save_path, 'wb') as f:
bar = tqdm( unit=unit, total=content_length )
bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
for chunk in r.iter_content(chunk_size=chunkSize):
if not chunk:
continue
bar.update( len(chunk) / scale )
bar.update( len(chunk))
f.write(chunk)
bar.close()
r.close()
def get_model(config, training=True, **model_kwargs):
name = config.name

View File

@ -21,6 +21,7 @@ from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt
tts = None
layout = {}
@ -49,9 +50,9 @@ def gradio_wrapper(inputs):
return wrapped_function
return decorated
# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
yamls = []
# returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
configs = []
for path in paths:
if not path.exists():
@ -60,10 +61,14 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
for yaml in path.glob("**/*.yaml"):
if "/logs/" in str(yaml):
continue
configs.append( yaml )
yamls.append( yaml )
for sft in path.glob("**/*.sft"):
if "/logs/" in str(sft):
continue
configs.append( sft )
return yamls
return configs
def get_dtypes():
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
@ -73,10 +78,10 @@ def get_attentions():
return AVAILABLE_ATTENTIONS + ["auto"]
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype, attention ):
gr.Info(f"Loading: {yaml}")
def load_model( config, device, dtype, attention ):
gr.Info(f"Loading: {config}")
try:
init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
except Exception as e:
raise gr.Error(e)
gr.Info(f"Loaded model")
@ -107,7 +112,7 @@ def load_sample( speaker ):
return data, (sr, wav)
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
global tts
if tts is not None:
@ -118,20 +123,32 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=No
tts = None
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=dtype)
parser.add_argument("--attention", type=str, default=attention)
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
if config:
if config.suffix == ".yaml" and not args.yaml:
args.yaml = config
elif config.suffix == ".sft" and not args.model:
args.model = config
if args.yaml:
config = args.yaml
elif args.model:
config = args.model
tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
return tts
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.yaml_path:
raise Exception("No YAML loaded.")
if not cfg.models:
raise Exception("No model loaded.")
if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
@ -220,8 +237,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
@gradio_wrapper(inputs=layout["inference_stt"]["inputs"].keys())
def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.yaml_path:
raise Exception("No YAML loaded.")
if not cfg.models:
raise Exception("No model loaded.")
if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
@ -306,6 +323,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@ -462,7 +480,7 @@ with ui:
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")