updated export routine to split LoRA weights from the state dict (should work with deepspeed)
parent 726a4b613f
commit 1d159b1476
@@ -793,10 +793,6 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False
 
-		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
-		if self.trainer.backend == "deepspeed" and self.lora is not None:
-			raise Exception("LoRAs are currently unsupported with deepspeed backend")
-
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()
@@ -331,14 +331,18 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)
 
 	def export(self, userdata={}, callback=None):
-		# to-do: lora exporting
-		if cfg.lora is not None:
-			return
-
 		for name, engine in self.items():
-			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			module = engine.module.state_dict()
+			lora = None
+			save_dir = cfg.ckpt_dir / name / "fp32.pth"
+
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+
 			state_dict = {
-				'module': engine.module.state_dict(),
+				'module': module,
+				'lora': lora,
 				"stats": {
 					"global_step": engine.global_step,
 					"micro_step": engine.micro_step,
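`lora_get_state_dict( module, split = True )` is defined elsewhere in the repo; as a rough sketch of the split it is being asked to perform here, assuming LoRA tensors can be identified by a `lora_` substring in their state-dict keys (hypothetical helper name, not the repo's actual implementation):

```python
# Sketch only: key-based split matching the call above; the real
# lora_get_state_dict() may use a different criterion.
def split_lora_state_dict( state_dict, split = True ):
	lora = { k: v for k, v in state_dict.items() if "lora_" in k }
	if not split:
		return lora
	module = { k: v for k, v in state_dict.items() if "lora_" not in k }
	return lora, module
```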
@@ -349,7 +353,8 @@ class Engines(dict[str, Engine]):
 			}
 			if callback:
 				state_dict = callback( state_dict, engine.hyper_config )
-			torch.save(state_dict, outpath)
+
+			torch.save(state_dict, save_dir)
 			print(f"Exported {name} to {outpath}")
 
 	def save_checkpoint(self, tag=None):
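For reference, reading the exported file back follows the state_dict layout written above; a minimal sketch (the path is illustrative, the real one comes from `cfg.ckpt_dir`, or `cfg.ckpt_dir / cfg.lora.full_name` when a LoRA is active):

```python
import torch

# Illustrative path; export() writes to cfg.ckpt_dir / <name or lora.full_name> / "fp32.pth"
state = torch.load("./ckpt/model/fp32.pth", map_location="cpu")

module_weights = state["module"]  # base weights, with LoRA tensors stripped when a LoRA is loaded
lora_weights   = state["lora"]    # split-out LoRA tensors, or None when no LoRA is configured
stats          = state["stats"]   # includes "global_step" and "micro_step"
```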
@@ -10,7 +10,7 @@ import math
 from typing import Optional, List
 
 # to-do: set cfg to decide
-USE_PARAMETRIZATION = False
+USE_PARAMETRIZATION = True
 
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
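Flipping `USE_PARAMETRIZATION` selects the `torch.nn.utils.parametrize` route over the "LoRA Linear for replacement" approach the comments describe. A minimal sketch of what the parametrization route looks like in general, with illustrative class name, rank, and alpha (not this repo's actual implementation):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
	"""Illustrative LoRA delta: the wrapped Linear sees W + (alpha / r) * (B @ A)."""
	def __init__(self, in_features, out_features, r = 8, alpha = 16):
		super().__init__()
		self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
		self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero-init so the delta starts at zero
		self.scale = alpha / r

	def forward(self, W):
		# called whenever the parametrized "weight" is accessed
		return W + self.scale * (self.lora_B @ self.lora_A)

linear = nn.Linear(1024, 1024)
parametrize.register_parametrization(
	linear, "weight", LoRAParametrization(linear.in_features, linear.out_features)
)
```

With parametrization, the original weight is kept under `parametrizations.weight.original` and the `lora_A` / `lora_B` tensors appear as separate state-dict entries, which is what makes a key-based split like the one in the new `export()` straightforward.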