updated the export routine to split LoRA weights from the state dict (should now work with the DeepSpeed backend)
parent 726a4b613f
commit 1d159b1476
@@ -793,10 +793,6 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False

-		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
-		if self.trainer.backend == "deepspeed" and self.lora is not None:
-			raise Exception("LoRAs are currently unsupported with deepspeed backend")
-
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()
@@ -331,14 +331,18 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)

 	def export(self, userdata={}, callback=None):
-		# to-do: lora exporting
-		if cfg.lora is not None:
-			return
-
 		for name, engine in self.items():
-			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			module = engine.module.state_dict()
+			lora = None
+			save_dir = cfg.ckpt_dir / name / "fp32.pth"
+
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+
 			state_dict = {
-				'module': engine.module.state_dict(),
+				'module': module,
+				'lora': lora,
 				"stats": {
 					"global_step": engine.global_step,
 					"micro_step": engine.micro_step,
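For context, a minimal sketch of the behavior assumed of lora_get_state_dict(module, split=True) — illustrative only, not the repo's actual implementation: it partitions a state dict by key name so the LoRA tensors can be saved apart from the base weights.

# Sketch only: assumed behavior of lora_get_state_dict, keying on the
# conventional "lora_" prefix in parameter names.
def lora_get_state_dict( state_dict, split = True ):
	lora = { name: weight for name, weight in state_dict.items() if "lora_" in name }
	if not split:
		return lora
	module = { name: weight for name, weight in state_dict.items() if "lora_" not in name }
	return lora, module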
@@ -349,7 +353,8 @@ class Engines(dict[str, Engine]):
 			}
 			if callback:
 				state_dict = callback( state_dict, engine.hyper_config )
-			torch.save(state_dict, outpath)
-			print(f"Exported {name} to {outpath}")
+
+			torch.save(state_dict, save_dir)
+			print(f"Exported {name} to {save_dir}")

 	def save_checkpoint(self, tag=None):
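A hypothetical load-side counterpart, to show how the exported structure would be consumed — nothing below is from this commit, and the strict=False handling is an assumption:

# Hypothetical loader (assumption, not part of this commit): restore the
# base weights, then overlay the split-out LoRA weights if present.
state_dict = torch.load(cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth")
model.load_state_dict(state_dict['module'], strict=False)
if state_dict.get('lora') is not None:
	model.load_state_dict(state_dict['lora'], strict=False)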
@@ -10,7 +10,7 @@ import math
 from typing import Optional, List

 # to-do: set cfg to decide
-USE_PARAMETRIZATION = False
+USE_PARAMETRIZATION = True

 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
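Flipping USE_PARAMETRIZATION to True opts into the torch.nn.utils.parametrize route instead of the drop-in Linear replacement. A minimal sketch of that technique, assuming illustrative names and rank/alpha values (this is not the repo's class):

import math
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
	# Sketch: expose the effective weight W + (B @ A) * (alpha / r) while
	# the wrapped Linear keeps its original weight parameter untouched.
	def __init__(self, in_features, out_features, r=8, alpha=16):
		super().__init__()
		self.lora_A = nn.Parameter(torch.zeros(r, in_features))
		self.lora_B = nn.Parameter(torch.zeros(out_features, r))
		# standard LoRA init: A random, B zero, so the initial delta is zero
		nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
		self.scaling = alpha / r

	def forward(self, weight):
		return weight + (self.lora_B @ self.lora_A) * self.scaling

linear = nn.Linear(16, 32)
parametrize.register_parametrization(
	linear, "weight",
	LoRAParametrization(linear.in_features, linear.out_features),
)

The upside of parametrization is that linear.weight transparently reads as the merged weight; the replacement approach instead needs replace_linear and weight copies, as the comments above note.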