updated export routine to split LoRA weights from the state dict (should work with deepspeed)

This commit is contained in:
mrq 2024-06-17 13:28:18 -05:00
parent 726a4b613f
commit 1d159b1476
3 changed files with 13 additions and 12 deletions
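The diff below calls lora_get_state_dict( module, split = True ) to pull the LoRA weights out of the module's state dict before saving. That helper is not part of this diff; as a rough sketch of what such a split could look like, assuming LoRA parameters can be recognized by a "lora_" substring in their key names (the helper name and the naming convention here are assumptions, not taken from the repository):

def split_lora_state_dict( state_dict, split = True ):
	# Hypothetical stand-in for lora_get_state_dict(): partition a state dict
	# into LoRA tensors and everything else, keyed on an assumed "lora_" marker.
	lora = { k: v for k, v in state_dict.items() if "lora_" in k }
	if not split:
		return lora
	module = { k: v for k, v in state_dict.items() if "lora_" not in k }
	return lora, module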


@@ -793,10 +793,6 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False
 
-		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
-		if self.trainer.backend == "deepspeed" and self.lora is not None:
-			raise Exception("LoRAs are currently unsupported with deepspeed backend")
-
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()


@@ -331,14 +331,18 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)
 
 	def export(self, userdata={}, callback=None):
-		# to-do: lora exporting
-		if cfg.lora is not None:
-			return
-
 		for name, engine in self.items():
-			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			module = engine.module.state_dict()
+			lora = None
+			save_dir = cfg.ckpt_dir / name / "fp32.pth"
+
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+
 			state_dict = {
-				'module': engine.module.state_dict(),
+				'module': module,
+				'lora': lora,
 				"stats": {
 					"global_step": engine.global_step,
 					"micro_step": engine.micro_step,
@@ -349,7 +353,8 @@ class Engines(dict[str, Engine]):
 			}
 
 			if callback:
 				state_dict = callback( state_dict, engine.hyper_config )
 
-			torch.save(state_dict, outpath)
+			torch.save(state_dict, save_dir)
 			print(f"Exported {name} to {outpath}")
 
 	def save_checkpoint(self, tag=None):
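With this change, a LoRA-enabled export writes the base weights under 'module' and the adapter weights under 'lora', to cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth". A minimal sketch of reading such a checkpoint back (the key layout comes from the diff above; the path is only an illustrative placeholder):

import torch

state_dict = torch.load("ckpt/<lora full_name>/fp32.pth", map_location="cpu")
base_weights = state_dict["module"]                 # plain model weights
lora_weights = state_dict["lora"]                   # None when no LoRA was configured
global_step  = state_dict["stats"]["global_step"]   # training progress saved alongside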


@@ -10,7 +10,7 @@ import math
 from typing import Optional, List
 
 # to-do: set cfg to decide
-USE_PARAMETRIZATION = False
+USE_PARAMETRIZATION = True
 
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
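Flipping USE_PARAMETRIZATION to True selects the parametrization-based LoRA path rather than the drop-in LoRA Linear replacement described in the comments above. As a rough sketch of what a parametrization-style LoRA looks like with torch.nn.utils.parametrize (rank, scaling, and names here are illustrative, not the repository's actual implementation):

import math
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class LoRAParametrization(nn.Module):
	# Low-rank update applied on top of an existing weight matrix.
	def __init__(self, out_features, in_features, rank = 8, alpha = 16):
		super().__init__()
		self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
		self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
		nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
		self.scale = alpha / rank

	def forward(self, weight):
		# returned value replaces module.weight wherever it is read
		return weight + self.scale * (self.lora_B @ self.lora_A)

linear = nn.Linear(64, 64)
parametrize.register_parametrization(
	linear, "weight", LoRAParametrization(linear.out_features, linear.in_features)
)

The upside of the parametrization route is that the original nn.Linear stays in place and only its weight access is wrapped, so there is no module-replacement and weight-copying step of the kind the comments above describe for the replacement approach.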