added an option to explicitly load a LoRA without having to lobotomize yourself by creating a YAML just to do so

mrq 2025-05-20 23:28:29 -05:00
parent 5018ddb107
commit fee02f4153
2 changed files with 35 additions and 5 deletions
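For orientation, a minimal sketch of what this commit enables; the module path and file names are hypothetical, while the --model/--lora flags and VALLE_* environment variables come from the diff below:

    # launch the web UI with an explicit model + LoRA, no YAML required:
    #   python3 -m vall_e.webui --model ./models/ar+nar.sft --lora ./models/lora.sft
    # or via the environment (useful on a HuggingFace Space):
    import os
    os.environ["VALLE_MODEL"] = "./models/ar+nar.sft"  # hypothetical path
    os.environ["VALLE_LORA"] = "./models/lora.sft"     # hypothetical path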


@@ -99,7 +99,7 @@ def torch_save( data, path, module_key=None ):
 		if metadata is None:
 			metadata = {}
-		return sft_save( data, path, metadata )
+		return sft_save( data, path, { k: v for k, v in metadata.items() if v is not None } )
 	return torch.save( data, path )

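The filter matters because safetensors metadata must be a flat str-to-str mapping, so a None value fails validation at save time. Assuming sft_save wraps safetensors' save_file, a rough illustration of why the Nones get dropped:

    import torch
    from safetensors.torch import save_file

    tensors = { "weight": torch.zeros(2, 2) }
    metadata = { "epoch": "3", "notes": None }  # hypothetical metadata

    # save_file rejects non-string metadata values, so strip the Nones first
    clean = { k: v for k, v in metadata.items() if v is not None }
    save_file( tensors, "model.sft", metadata=clean )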

@@ -85,21 +85,49 @@ def gradio_wrapper(inputs):
 	return decorated

 # returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
-def get_model_paths(paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
+def get_model_paths(paths=["./training/", "./models/", "./data/models/"] ):
 	configs = []
 	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
 		if not path.exists():
 			continue
 		for yaml in path.glob("**/*.yaml"):
 			if "/logs/" in str(yaml):
 				continue
+			if "lora" in str(yaml):
+				continue
 			configs.append( yaml )
 		for sft in path.glob("**/*.sft"):
 			if "/logs/" in str(sft):
 				continue
+			if "lora" in str(sft):
+				continue
 			configs.append( sft )
 	configs = [ str(p) for p in configs ]
 	return configs
+
+def get_lora_paths(paths=["./training/", "./models/", "./data/models/"] ):
+	configs = []
+	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
+		if not path.exists():
+			continue
+		for sft in path.glob("**/*.sft"):
+			if "/logs/" in str(sft):
+				continue
+			if "lora" not in str(sft):
+				continue
+			configs.append( sft )
+	configs = [ str(p) for p in configs ]
+	return configs
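The convention assumed by these helpers: any .sft whose path contains "lora" is treated as a LoRA, listed by get_lora_paths() and now filtered out of the model list, while anything under a logs directory is ignored by both. A hypothetical layout that satisfies it:

    # ./models/ar+nar.sft        -> listed by get_model_paths()
    # ./models/lora-speaker.sft  -> listed by get_lora_paths() ("lora" in path)
    # ./training/logs/ckpt.sft   -> ignored by both ("/logs/" in path)
    models = get_model_paths()  # plain strings, ready for a gr.Dropdown's choices
    loras = get_lora_paths()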
@@ -113,10 +141,10 @@ def get_attentions():
 	return AVAILABLE_ATTENTIONS + ["auto"]

 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( config, device, dtype, attention ):
+def load_model( config, lora, device, dtype, attention ):
 	gr.Info(f"Loading: {config}")
 	try:
-		init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
+		init_tts( config=Path(config), lora=Path(lora) if lora is not None else None, restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
@@ -444,6 +472,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 parser = argparse.ArgumentParser(allow_abbrev=False)
 parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
 parser.add_argument("--share", action="store_true")
 parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -648,8 +677,9 @@ with ui:
 		with gr.Column(scale=7):
 			with gr.Row():
 				layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model", info="Model to load. Can load from a config YAML or the weights itself.")
-				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
+				layout["settings"]["inputs"]["loras"] = gr.Dropdown(choices=get_lora_paths(), value=args.lora, label="LoRA", info="LoRA to load from the weights itself.")
 			with gr.Row():
+				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
 				layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision", info="Tensor type to load the model under.")
 				layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions", info="Attention mechanism to utilize.")