added option to explicitly load a LoRA without having to create a YAML config just to do so
parent 5018ddb107
commit fee02f4153
@@ -99,7 +99,7 @@ def torch_save( data, path, module_key=None ):
 		if metadata is None:
 			metadata = {}
 
-		return sft_save( data, path, metadata )
+		return sft_save( data, path, { k: v for k, v in metadata.items() if v is not None } )
 
 	return torch.save( data, path )
 
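The change above filters None values out of the metadata before handing it to sft_save. Assuming sft_save wraps safetensors' save_file (a guess from the naming; the tensor and key names below are illustrative), the reason is that safetensors metadata must map strings to strings, so a None value fails at save time. A minimal sketch of the same filtering:

import torch
from safetensors.torch import save_file

tensors = { "weight": torch.zeros( 2, 2 ) }   # illustrative tensor dict
metadata = { "epoch": "3", "lora": None }     # a None value is not a valid str

# mirror the committed fix: drop None entries so metadata stays Dict[str, str]
metadata = { k: v for k, v in metadata.items() if v is not None }
save_file( tensors, "model.sft", metadata=metadata )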
@@ -85,21 +85,49 @@ def gradio_wrapper(inputs):
 	return decorated
 
 # returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
-def get_model_paths(paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
+def get_model_paths(paths=["./training/", "./models/", "./data/models/"] ):
 	configs = []
 
 	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
+
 		if not path.exists():
 			continue
 
 		for yaml in path.glob("**/*.yaml"):
 			if "/logs/" in str(yaml):
 				continue
+			if "lora" in str(yaml):
+				continue
 			configs.append( yaml )
 
+		for sft in path.glob("**/*.sft"):
+			if "/logs/" in str(sft):
+				continue
+			if "lora" in str(sft):
+				continue
+			configs.append( sft )
+
 	configs = [ str(p) for p in configs ]
 
 	return configs
 
+def get_lora_paths(paths=["./training/", "./models/", "./data/models/"] ):
+	configs = []
+
+	for path in paths:
+		if not isinstance( path, Path ):
+			path = Path(path)
+
+		if not path.exists():
+			continue
+
+		for sft in path.glob("**/*.sft"):
+			if "/logs/" in str(sft):
+				continue
+			if "lora" not in str(sft):
+				continue
+			configs.append( sft )
+
+	configs = [ str(p) for p in configs ]
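Together, get_model_paths() now accepts plain strings, picks up bare .sft weights as well as YAML configs, and skips anything with "lora" in its path; get_lora_paths() inverts that filter. A quick usage sketch against the defaults (note the substring test means a model stored under a directory named "lora" would also be skipped):

# usage sketch; both helpers return plain strings, ready for a gr.Dropdown
model_choices = get_model_paths()   # yaml configs + full .sft weights, loras excluded
lora_choices  = get_lora_paths()    # only .sft files with "lora" in their path

print( model_choices[:3] )
print( lora_choices[:3] )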
@@ -113,10 +141,10 @@ def get_attentions():
 	return AVAILABLE_ATTENTIONS + ["auto"]
 
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( config, device, dtype, attention ):
+def load_model( config, lora, device, dtype, attention ):
 	gr.Info(f"Loading: {config}")
 	try:
-		init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
+		init_tts( config=Path(config), lora=Path(lora) if lora is not None else None, restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
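load_model() now threads an optional LoRA path through to init_tts; passing None keeps the old behavior. A calling sketch (the keyword signature is taken from the diff; the file names are illustrative, not from the repo):

# calling sketch; paths are illustrative
load_model(
	config="./models/config.yaml",      # base model config or weights
	lora="./models/lora/adapter.sft",   # or None to load without a LoRA
	device="cuda:0",
	dtype="auto",
	attention="auto",
)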
@@ -444,6 +472,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser = argparse.ArgumentParser(allow_abbrev=False)
 	parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 	parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+	parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
 	parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
 	parser.add_argument("--share", action="store_true")
 	parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
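The --lora flag falls back to the VALLE_LORA environment variable so a HuggingFace Space can set it without CLI access. A self-contained sketch of that pattern (argparse applies type conversion to string defaults, so the env value also becomes a Path):

import argparse, os
from pathlib import Path

parser = argparse.ArgumentParser(allow_abbrev=False)
# CLI value wins; the env var is the fallback, e.g. inside a HF Space
parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None))

args = parser.parse_args([])  # empty argv for the sketch, so only the default applies
print( args.lora )            # Path from the env var if set, otherwise None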
@@ -648,8 +677,9 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Row():
 					layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model", info="Model to load. Can load from a config YAML or the weights itself.")
-					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
+					layout["settings"]["inputs"]["loras"] = gr.Dropdown(choices=get_lora_paths(), value=args.yaml or args.lora, label="LoRA", info="LoRA to load. Can load from a config YAML or the weights itself.")
 				with gr.Row():
+					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
 					layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision", info="Tensor type to load the model under.")
 					layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions", info="Attention mechanism to utilize.")
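For completeness, a hypothetical sketch of how the new LoRA dropdown would feed load_model; the commit only shows the dropdown declarations, so the button and its wiring below are assumptions:

# hypothetical wiring; the commit does not show this part
load_button = gr.Button("Load Model")
load_button.click(
	fn=load_model,
	inputs=[
		layout["settings"]["inputs"]["models"],
		layout["settings"]["inputs"]["loras"],   # the new LoRA dropdown
		layout["settings"]["inputs"]["device"],
		layout["settings"]["inputs"]["dtype"],
		layout["settings"]["inputs"]["attentions"],
	],
	outputs=[],
)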