From ccf71dc1b6ff8b6256382dbb12a535396caaf260 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 25 Oct 2024 22:15:15 -0500 Subject: [PATCH] added option to load from a model state dict directly instead of a yaml (to-do: do this for LoRAs too), automatically download the default model if none is provided --- vall_e/config.py | 77 ++++++++++++++++++++++++++------------ vall_e/engines/__init__.py | 4 ++ vall_e/inference.py | 12 +++++- vall_e/models/__init__.py | 46 ++++++++++++----------- vall_e/webui.py | 52 ++++++++++++++++--------- 5 files changed, 128 insertions(+), 63 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 0d73e5a..e486681 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -21,6 +21,7 @@ from functools import cached_property from pathlib import Path from .utils.distributed import world_size +from .utils.io import torch_load from .utils import set_seed, prune_missing @dataclass() @@ -30,7 +31,13 @@ class BaseConfig: @property def cfg_path(self): - return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data" + if self.yaml_path: + return Path(self.yaml_path.parent) + + if self.model_path: + return Path(self.model_path.parent) + + return Path(__file__).parent.parent / "data" @property def rel_path(self): @@ -93,8 +100,6 @@ class BaseConfig: def prune_missing( cls, yaml ): default = cls(**{}) default.format() - #default = json.loads(default.dumps()) - yaml, missing = prune_missing( source=default, dest=yaml ) if missing: _logger.warning(f'Missing keys in YAML: {missing}') @@ -108,6 +113,17 @@ class BaseConfig: state = cls.prune_missing( state ) return cls(**state) + @classmethod + def from_model( cls, model_path ): + if not model_path.exists(): + raise Exception(f'Model path does not exist: {model_path}') + + # load state dict and copy its stored model config + state_dict = torch_load( model_path )["config"] + + state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path } + return cls(**state) + @classmethod def from_cli(cls, args=sys.argv): # legacy support for yaml=`` format @@ -117,8 +133,12 @@ class BaseConfig: parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too args, unknown = parser.parse_known_args(args=args) + if args.model: + return cls.from_model( args.model ) + if args.yaml: return cls.from_yaml( args.yaml ) @@ -807,10 +827,14 @@ class Config(BaseConfig): return diskcache.Cache(self.cache_dir).memoize return lambda: lambda x: x - # I don't remember why this is needed + # this gets called from vall_e.inference def load_yaml( self, config_path ): tmp = Config.from_yaml( config_path ) self.__dict__.update(tmp.__dict__) + + def load_model( self, config_path ): + tmp = Config.from_model( config_path ) + self.__dict__.update(tmp.__dict__) def load_hdf5( self, write=False ): if hasattr(self, 'hdf5'): @@ -870,7 +894,27 @@ class Config(BaseConfig): if isinstance(self.optimizations, type): self.optimizations = dict() - self.dataset = Dataset(**self.dataset) + if isinstance( self.dataset, dict ): + self.dataset = Dataset(**self.dataset) + + if isinstance( self.hyperparameters, dict ): + self.hyperparameters = Hyperparameters(**self.hyperparameters) + + if 
isinstance( self.evaluation, dict ): + self.evaluation = Evaluation(**self.evaluation) + + if isinstance( self.trainer, dict ): + self.trainer = Trainer(**self.trainer) + + if isinstance( self.trainer.deepspeed, dict ): + self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed) + + if isinstance( self.inference, dict ): + self.inference = Inference(**self.inference) + + if isinstance( self.optimizations, dict ): + self.optimizations = Optimizations(**self.optimizations) + # convert to expanded paths self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ] self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ] @@ -906,28 +950,15 @@ class Config(BaseConfig): model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums") - self.models = [ Model(**model) for model in self.models ] - self.loras = [ LoRA(**lora) for lora in self.loras ] + self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ] + self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ] if not self.models: self.models = [ Model() ] for model in self.models: - if not isinstance( model.experimental, dict ): - continue - model.experimental = ModelExperimentalSettings(**model.experimental) - - self.hyperparameters = Hyperparameters(**self.hyperparameters) - - self.evaluation = Evaluation(**self.evaluation) - - self.trainer = Trainer(**self.trainer) - - if not isinstance(self.trainer.deepspeed, type): - self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed) - - self.inference = Inference(**self.inference) - self.optimizations = Optimizations(**self.optimizations) + if isinstance( model.experimental, dict ): + model.experimental = ModelExperimentalSettings(**model.experimental) if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler: self.hyperparameters.scheduler = self.hyperparameters.scheduler_type @@ -961,7 +992,7 @@ class Config(BaseConfig): try: from transformers import PreTrainedTokenizerFast - tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None + tokenizer_path = self.rel_path / self.tokenizer_path if tokenizer_path and not tokenizer_path.exists(): tokenizer_path = Path("./data/") / self.tokenizer_path diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 0feec2e..a4c4901 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -57,6 +57,10 @@ def load_engines(training=True, **model_kwargs): tag = open(checkpoint_path).read() checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + # if loaded using --model= + if cfg.model_path and cfg.model_path.exists(): + load_path = cfg.model_path + if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): _logger.warning(f"Checkpoint missing, but weights found: {load_path}") loads_state_dict = True diff --git a/vall_e/inference.py b/vall_e/inference.py index 6463f18..88684d1 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -19,6 +19,7 @@ from .models import get_models from .models.lora import enable_lora from .engines import load_engines, deepspeed_available from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize +from .models import download_model, DEFAULT_MODEL_PATH if deepspeed_available: import deepspeed @@ -34,9 +35,18 @@ class TTS(): self.loading = False def 
load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ): - if config: + if not config: + download_model() + config = DEFAULT_MODEL_PATH + + if config.suffix == ".yaml": _logger.info(f"Loading YAML: {config}") cfg.load_yaml( config ) + elif config.suffix == ".sft": + _logger.info(f"Loading model: {config}") + cfg.load_model( config ) + else: + raise Exception(f"Unknown config passed: {config}") try: cfg.format( training=False ) diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index c474a9e..2ff6c53 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -4,31 +4,20 @@ import requests from tqdm import tqdm from pathlib import Path +import time + _logger = logging.getLogger(__name__) # to-do: implement automatically downloading model -DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models' +DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models' +DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-llama-8.sft" DEFAULT_MODEL_URLS = { - 'ar+nar-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft', + 'ar+nar-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft', } # kludge, probably better to use HF's model downloader function # to-do: write to a temp file then copy so downloads can be interrupted -def download_model( save_path, chunkSize = 1024, unit = "MiB" ): - scale = 1 - if unit == "KiB": - scale = (1024) - elif unit == "MiB": - scale = (1024 * 1024) - elif unit == "MiB": - scale = (1024 * 1024 * 1024) - elif unit == "KB": - scale = (1000) - elif unit == "MB": - scale = (1000 * 1000) - elif unit == "MB": - scale = (1000 * 1000 * 1000) - +def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ): name = save_path.name url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None if url is None: @@ -37,19 +26,32 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ): if not save_path.parent.exists(): save_path.parent.mkdir(parents=True, exist_ok=True) - r = requests.get(url, stream=True) - content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale + headers = {} + # check if modified + if save_path.exists(): + headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))} + + r = requests.get(url, headers=headers, stream=True) + # not modified + if r.status_code == 304: + r.close() + return + + # to-do: validate lengths match + + content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) with open(save_path, 'wb') as f: - bar = tqdm( unit=unit, total=content_length ) + bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" ) for chunk in r.iter_content(chunk_size=chunkSize): if not chunk: continue - - bar.update( len(chunk) / scale ) + bar.update( len(chunk)) f.write(chunk) bar.close() + r.close() + def get_model(config, training=True, **model_kwargs): name = config.name diff --git a/vall_e/webui.py b/vall_e/webui.py index c8b71c1..6cc420a 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -21,6 +21,7 @@ from .utils.io import json_read, json_stringify from .emb.qnt import decode_to_wave from .data import get_lang_symmap, get_random_prompt + tts = None layout = {} @@ -49,9 +50,9 @@ def gradio_wrapper(inputs): return wrapped_function return 
decorated -# returns a list of models, assuming the models are placed under ./training/ or ./models/ -def get_model_paths( paths=[Path("./training/"), Path("./models/")] ): - yamls = [] +# returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/ +def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ): + configs = [] for path in paths: if not path.exists(): @@ -60,10 +61,14 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ): for yaml in path.glob("**/*.yaml"): if "/logs/" in str(yaml): continue + configs.append( yaml ) + + for sft in path.glob("**/*.sft"): + if "/logs/" in str(sft): + continue + configs.append( sft ) - yamls.append( yaml ) - - return yamls + return configs def get_dtypes(): return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"] @@ -73,10 +78,10 @@ def get_attentions(): return AVAILABLE_ATTENTIONS + ["auto"] #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys()) -def load_model( yaml, device, dtype, attention ): - gr.Info(f"Loading: {yaml}") +def load_model( config, device, dtype, attention ): + gr.Info(f"Loading: {config}") try: - init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention ) + init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention ) except Exception as e: raise gr.Error(e) gr.Info(f"Loaded model") @@ -107,7 +112,7 @@ def load_sample( speaker ): return data, (sr, wav) -def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None): +def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None): global tts if tts is not None: @@ -118,20 +123,32 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=No tts = None parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) - parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--device", type=str, default=device) parser.add_argument("--amp", action="store_true") parser.add_argument("--dtype", type=str, default=dtype) parser.add_argument("--attention", type=str, default=attention) args, unknown = parser.parse_known_args() - tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) + if config: + if config.suffix == ".yaml" and not args.yaml: + args.yaml = config + elif config.suffix == ".sft" and not args.model: + args.model = config + + if args.yaml: + config = args.yaml + elif args.model: + config = args.model + + tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) return tts @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys()) def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): - if not cfg.yaml_path: - raise Exception("No YAML loaded.") + if not cfg.models: + raise Exception("No model loaded.") if kwargs.pop("dynamic-sampling", False): kwargs['min-ar-temp'] = 0.01 if 
kwargs['ar-temp'] > 0.01 else 0.0 @@ -220,8 +237,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): @gradio_wrapper(inputs=layout["inference_stt"]["inputs"].keys()) def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): - if not cfg.yaml_path: - raise Exception("No YAML loaded.") + if not cfg.models: + raise Exception("No model loaded.") if kwargs.pop("dynamic-sampling", False): kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 @@ -306,6 +323,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): # setup args parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too +parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--listen", default=None, help="Path for Gradio to listen on") parser.add_argument("--share", action="store_true") parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ) @@ -462,7 +480,7 @@ with ui: with gr.Row(): with gr.Column(scale=7): with gr.Row(): - layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model") + layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model") layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device") layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
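
Usage sketch for the new loading path (a minimal example under stated assumptions, not part of the patch itself): it assumes TTS.__init__ forwards its `config` argument to load_config() the same way init_tts() in webui.py does, and the default filename and ./data/models/ location follow DEFAULT_MODEL_PATH in vall_e/models/__init__.py.

# sketch only: load the engine straight from a .sft weights file instead of a YAML
from pathlib import Path
from vall_e.inference import TTS

# explicit weights file; its embedded "config" dict supplies the model settings
tts = TTS( config=Path("./data/models/ar+nar-llama-8.sft"), device="cuda", dtype="float16" )

# passing no config downloads the default model into ./data/models/ and loads it
tts = TTS( config=None )

From the CLI or webui, the equivalent is --model=./data/models/ar+nar-llama-8.sft (or the VALLE_MODEL environment variable), which takes the place of --yaml when loading directly from a model state dict.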