From fee02f415312b2097ecaa5d474cdd63677c51963 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 20 May 2025 23:28:29 -0500
Subject: [PATCH] added option to explicitly load a lora without having to
 lobotomize yourself with creating a yaml just to do so

---
 vall_e/utils/io.py |  2 +-
 vall_e/webui.py    | 38 ++++++++++++++++++++++++++++++++++----
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py
index 5b1239c..7a9f74e 100644
--- a/vall_e/utils/io.py
+++ b/vall_e/utils/io.py
@@ -99,7 +99,7 @@ def torch_save( data, path, module_key=None ):
 		if metadata is None:
 			metadata = {}
 
-		return sft_save( data, path, metadata )
+		return sft_save( data, path, { k: v for k, v in metadata.items() if v is not None } )
 
 	return torch.save( data, path )
 
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 83201ff..bfdfd8e 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -85,21 +85,49 @@ def gradio_wrapper(inputs):
 	return decorated
 
 # returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
-def get_model_paths(paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
+def get_model_paths(paths=["./training/", "./models/", "./data/models/"] ):
 	configs = []
 
 	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
+
 		if not path.exists():
 			continue
 
 		for yaml in path.glob("**/*.yaml"):
 			if "/logs/" in str(yaml):
 				continue
+			if "lora" in str(yaml):
+				continue
 			configs.append( yaml )
 
 		for sft in path.glob("**/*.sft"):
 			if "/logs/" in str(sft):
 				continue
+			if "lora" in str(sft):
+				continue
+			configs.append( sft )
+
+	configs = [ str(p) for p in configs ]
+
+	return configs
+
+def get_lora_paths(paths=["./training/", "./models/", "./data/models/"] ):
+	configs = []
+
+	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
+
+		if not path.exists():
+			continue
+
+		for sft in path.glob("**/*.sft"):
+			if "/logs/" in str(sft):
+				continue
+			if "lora" not in str(sft):
+				continue
 			configs.append( sft )
 
 	configs = [ str(p) for p in configs ]
@@ -113,10 +141,10 @@ def get_attentions():
 	return AVAILABLE_ATTENTIONS + ["auto"]
 
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( config, device, dtype, attention ):
+def load_model( config, lora, device, dtype, attention ):
 	gr.Info(f"Loading: {config}")
 	try:
-		init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
+		init_tts( config=Path(config), lora=Path(lora) if lora is not None else None, restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
@@ -444,6 +472,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 parser = argparse.ArgumentParser(allow_abbrev=False)
 parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
 parser.add_argument("--share", action="store_true")
 parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -648,8 +677,9 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Row():
 					layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model", info="Model to load. Can load from a config YAML or the weights itself.")
-					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
+					layout["settings"]["inputs"]["loras"] = gr.Dropdown(choices=get_lora_paths(), value=args.lora, label="LoRA", info="LoRA weights (.sft) to load on top of the base model.")
 				with gr.Row():
+					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
 					layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision", info="Tensor type to load the model under.")
 					layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions", info="Attention mechanism to utilize.")
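
Usage sketch (the checkpoint paths below are hypothetical placeholders;
launching the web UI as a module follows the repo's usual invocation):

    python3 -m vall_e.webui --model ./models/model.sft --lora ./models/loras/lora.sft

    # or through the environment, e.g. in a HuggingFace Space:
    VALLE_MODEL=./models/model.sft VALLE_LORA=./models/loras/lora.sft python3 -m vall_e.webui

Two details worth noting: get_lora_paths() matches the literal substring
"lora" anywhere in an .sft path, so LoRA weights only appear in the dropdown
if their path contains it (e.g. ./models/loras/), and get_model_paths() now
hides any such path from the model dropdown. The io.py change strips
None-valued entries from the metadata before sft_save(), presumably because
safetensors metadata must be a flat str-to-str mapping, so a None value
would fail to serialize.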