load exported LoRA weights if exists (to-do: make a better LoRA loading mechanism)
This commit is contained in:
parent
7c9144ff22
commit
6c2e00ce2a
|
@@ -166,6 +166,14 @@ def load_engines(training=True):
 				model.load_state_dict(state, strict=cfg.trainer.strict_loading)

+			# load lora weights if exists
+			if cfg.lora is not None:
+				lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+				if lora_path.exists():
+					state = torch.load(lora_path, map_location=torch.device(cfg.device))
+					state = state['lora' if 'lora' in state else 'module']
+					model.load_state_dict(state, strict=False)
+
 		# wrap if DDP is requested
 		if ddp:
 			model = ddp_model(model)
|
Loading…
Reference in New Issue
Block a user