From 4049f51ba9c601a17d6c7ab668206cb54645a54c Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 26 Oct 2024 00:13:10 -0500 Subject: [PATCH] added option to load lora directly from the model file itself with --lora --- vall_e/__main__.py | 11 ++++++++++- vall_e/config.py | 24 +++++++++++++----------- vall_e/demo.py | 10 +++++++++- vall_e/engines/__init__.py | 10 +++++++--- vall_e/inference.py | 8 ++++---- vall_e/webui.py | 11 +++++++++-- 6 files changed, 52 insertions(+), 22 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index c9edacc..6308554 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -17,6 +17,8 @@ def main(): parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None) + parser.add_argument("--model", type=Path, default=None) + parser.add_argument("--lora", type=Path, default=None) parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-nar-levels", type=int, default=7) @@ -53,7 +55,14 @@ def main(): parser.add_argument("--attention", type=str, default=None) args = parser.parse_args() - tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention ) + config = None + + if args.yaml: + config = args.yaml + elif args.model: + config = args.model + + tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention ) output = tts.inference( text=args.text, references=args.references, diff --git a/vall_e/config.py b/vall_e/config.py index e486681..d176a11 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -27,15 +27,11 @@ from .utils import set_seed, prune_missing @dataclass() class BaseConfig: yaml_path: str | None = None # path passed in through --yaml - model_path: str | None = None # path passed in through --model @property def cfg_path(self): if self.yaml_path: return Path(self.yaml_path.parent) - - if self.model_path: - return Path(self.model_path.parent) return Path(__file__).parent.parent / "data" @@ -114,14 +110,15 @@ class BaseConfig: return cls(**state) @classmethod - def from_model( cls, model_path ): + def from_model( cls, model_path, lora_path=None ): if not model_path.exists(): raise Exception(f'Model path does not exist: {model_path}') # load state dict and copy its stored model config - state_dict = torch_load( model_path )["config"] + model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else [] + lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else [] - state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path } + state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } } return cls(**state) @classmethod @@ -134,10 +131,11 @@ class BaseConfig: parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too args, unknown = parser.parse_known_args(args=args) if args.model: - return cls.from_model( args.model ) + return cls.from_model( args.model, args.lora ) if args.yaml: return cls.from_yaml( args.yaml ) @@ -276,6 +274,7 @@ class Model: frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training attention: str = "auto" # for llama arch_types: attention used dropout: float = 0.1 # adjustable dropout value + path: Path | None = None #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good loss_factors: dict = field(default_factory=lambda: {}) capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such @@ -408,6 +407,7 @@ class LoRA: embeddings: bool = False # train the embedding too parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA + path: Path | None = None @property def full_name(self): @@ -832,8 +832,8 @@ class Config(BaseConfig): tmp = Config.from_yaml( config_path ) self.__dict__.update(tmp.__dict__) - def load_model( self, config_path ): - tmp = Config.from_model( config_path ) + def load_model( self, config_path, lora_path=None ): + tmp = Config.from_model( config_path, lora_path ) self.__dict__.update(tmp.__dict__) def load_hdf5( self, write=False ): @@ -870,6 +870,9 @@ class Config(BaseConfig): def format( self, training=True ): + print( self.models ) + print( self.loras ) + if isinstance(self.dataset, type): self.dataset = dict() @@ -949,7 +952,6 @@ class Config(BaseConfig): _logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}") model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums") - self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ] self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ] diff --git a/vall_e/demo.py b/vall_e/demo.py index 9d46e5d..5064065 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -42,6 +42,8 @@ def main(): parser = argparse.ArgumentParser("VALL-E TTS Demo") parser.add_argument("--yaml", type=Path, default=None) + parser.add_argument("--model", type=Path, default=None) + parser.add_argument("--lora", type=Path, default=None) parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--skip-existing", action="store_true") @@ -93,8 +95,14 @@ def main(): parser.add_argument("--comparison", type=str, default=None) args = parser.parse_args() + + config = None + if args.yaml: + config = args.yaml + elif args.model: + config = args.model - tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp ) + tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp ) if not args.demo_dir: args.demo_dir = Path("./data/demo/") diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index a4c4901..597b9d2 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -58,8 +58,8 @@ def load_engines(training=True, **model_kwargs): checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) # if loaded using --model= - if cfg.model_path and cfg.model_path.exists(): - load_path = cfg.model_path + if model.config.path and model.config.path.exists(): + load_path = model.config.path if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): _logger.warning(f"Checkpoint missing, but weights found: {load_path}") @@ -208,7 +208,11 @@ def load_engines(training=True, **model_kwargs): # load lora weights if exists if cfg.lora is not None: - lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + if cfg.lora.path: + lora_path = cfg.lora.path + else: + lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + if lora_path.exists(): _logger.info( f"Loaded LoRA state dict: {lora_path}" ) diff --git a/vall_e/inference.py b/vall_e/inference.py index 88684d1..c9e7786 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -25,16 +25,16 @@ if deepspeed_available: import deepspeed class TTS(): - def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ): + def __init__( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ): self.loading = True # yes I can just grab **kwargs and forward them here - self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention ) + self.load_config( config=config, lora=lora, device=device, amp=amp, dtype=dtype, attention=attention ) self.load_model() self.loading = False - def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ): + def load_config( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ): if not config: download_model() config = DEFAULT_MODEL_PATH @@ -44,7 +44,7 @@ class TTS(): cfg.load_yaml( config ) elif config.suffix == ".sft": _logger.info(f"Loading model: {config}") - cfg.load_model( config ) + cfg.load_model( config, lora ) else: raise Exception(f"Unknown config passed: {config}") diff --git a/vall_e/webui.py b/vall_e/webui.py index 6cc420a..1c29f8a 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -112,7 +112,7 @@ def load_sample( speaker ): return data, (sr, wav) -def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None): +def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None): global tts if tts is not None: @@ -125,6 +125,7 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention= parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--device", type=str, default=device) parser.add_argument("--amp", action="store_true") parser.add_argument("--dtype", type=str, default=dtype) @@ -137,12 +138,18 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention= elif config.suffix == ".sft" and not args.model: args.model = config + if lora and not args.lora: + args.lora = lora + if args.yaml: config = args.yaml elif args.model: config = args.model - tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) + if args.lora: + lora = args.lora + + tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) return tts @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())