actually make deepspeed work with LoRAs
This commit is contained in:
parent
1d159b1476
commit
7047fcc6e2
|
@ -12,6 +12,8 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
|||
|
||||
from ..models import get_models
|
||||
from ..utils import wrapper as ml
|
||||
from ..models.lora import apply_lora
|
||||
|
||||
import torch
|
||||
import re
|
||||
|
||||
|
@ -30,6 +32,8 @@ def load_engines(training=True):
|
|||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
hyper_config = model.config
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
|
@ -51,6 +55,9 @@ def load_engines(training=True):
|
|||
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
||||
model.model = ml.replace_embedding( model.model )
|
||||
|
||||
for lora in cfg.loras:
|
||||
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
|
||||
|
||||
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
optimizer_class = None
|
||||
scheduler_class = None
|
||||
|
@ -153,8 +160,6 @@ def load_engines(training=True):
|
|||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
hyper_config = model.config
|
||||
|
||||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
model = ddp_model(model)
|
||||
|
@ -178,9 +183,6 @@ def load_engines(training=True):
|
|||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
for name, engine in engines.items():
|
||||
engine.load_loras()
|
||||
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ def default_feeder(engine, batch):
|
|||
from ..config import cfg
|
||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
|
||||
from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
|
||||
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
@ -190,17 +190,6 @@ class Engine():
|
|||
if 'lora' in state:
|
||||
lora_load_state_dict( self.module, state['lora'] )
|
||||
|
||||
|
||||
def load_loras( self ):
|
||||
# apply lora weights
|
||||
for lora in cfg.loras:
|
||||
self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
|
||||
|
||||
lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
|
||||
if lora_path.exists():
|
||||
state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
|
||||
self.module = lora_load_state_dict( self.module, state_dict )
|
||||
|
||||
def eval(self):
|
||||
return self.module.eval()
|
||||
|
||||
|
@ -334,11 +323,11 @@ class Engines(dict[str, Engine]):
|
|||
for name, engine in self.items():
|
||||
module = engine.module.state_dict()
|
||||
lora = None
|
||||
save_dir = cfg.ckpt_dir / name / "fp32.pth"
|
||||
save_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
|
||||
if cfg.lora is not None:
|
||||
lora, module = lora_get_state_dict( module, split = True )
|
||||
save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
|
||||
save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
|
||||
|
||||
state_dict = {
|
||||
'module': module,
|
||||
|
@ -352,9 +341,9 @@ class Engines(dict[str, Engine]):
|
|||
"userdata": userdata
|
||||
}
|
||||
if callback:
|
||||
state_dict = callback( state_dict, engine.hyper_config )
|
||||
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
|
||||
|
||||
torch.save(state_dict, save_dir)
|
||||
torch.save(state_dict, save_path)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
|
|
|
@ -68,7 +68,12 @@ class Engine(DeepSpeedEngine):
|
|||
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
return freeze_non_lora_weights( self.module )
|
||||
for name, param in self.module.named_parameters():
|
||||
should = 'lora_' in name
|
||||
param.requires_grad_(should)
|
||||
if not should:
|
||||
self._frozen_params.add(param)
|
||||
return
|
||||
|
||||
for name, param in self.module.named_parameters():
|
||||
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
|
||||
|
@ -128,16 +133,6 @@ class Engine(DeepSpeedEngine):
|
|||
|
||||
return super().save_checkpoint( save_dir, **kwargs )
|
||||
|
||||
def load_loras( self ):
|
||||
# apply lora weights
|
||||
for lora in cfg.loras:
|
||||
self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
|
||||
|
||||
lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
|
||||
if lora_path.exists():
|
||||
state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
|
||||
self.module = lora_load_state_dict( self.module, state_dict )
|
||||
|
||||
def traverse(self, *args, **kwargs):
|
||||
with ml.autocast():
|
||||
self.forward(*args, **kwargs)
|
||||
|
|
|
@ -6,9 +6,10 @@ import torch.nn
|
|||
from .data import get_phone_symmap
|
||||
from .engines import load_engines
|
||||
from .config import cfg
|
||||
from .models.lora import lora_get_state_dict
|
||||
|
||||
# stitches embeddings into one embedding + classifier => lm_head
|
||||
def convert_to_hf( state_dict, config = None ):
|
||||
# stitches embeddings into one embedding & classifier => lm_head
|
||||
def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
|
||||
token_dim = 1024
|
||||
embedding = torch.nn.Embedding(n_tokens, token_dim)
|
||||
|
@ -53,22 +54,46 @@ def convert_to_hf( state_dict, config = None ):
|
|||
state_dict['module']['lm_head.weight'] = out_proj
|
||||
del state_dict['module']['classifier.bias']
|
||||
|
||||
torch.save(state_dict, "./data/export_test.pth")
|
||||
return state_dict
|
||||
|
||||
raise Exception("!")
|
||||
def extract_lora( state_dict, config = None, save_path = None ):
|
||||
lora = state_dict["lora"] if "lora" in state_dict else None
|
||||
# should always be included, but just in case
|
||||
if lora is None and "module" in state_dict:
|
||||
lora, module = lora_get_state_dict( state_dict["module"], split = True )
|
||||
state_dict["module"] = module
|
||||
state_dict["lora"] = lora
|
||||
|
||||
# should raise an exception since there's nothing to extract, or at least a warning
|
||||
if not lora:
|
||||
return state_dict
|
||||
|
||||
# save lora specifically
|
||||
# should probably export other attributes, similar to what SD LoRAs do
|
||||
save_path = save_path.parent / "lora.pth"
|
||||
torch.save( { "module": lora }, save_path )
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
||||
callback = convert_to_hf if args.hf else None
|
||||
callback = None
|
||||
if args.hf:
|
||||
callback = convert_to_hf
|
||||
elif args.lora:
|
||||
callback = extract_lora
|
||||
|
||||
if args.hf and args.lora:
|
||||
raise Exception("Requesting more than one callback")
|
||||
|
||||
engines = load_engines()
|
||||
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
|
||||
|
|
|
@ -14,7 +14,7 @@ USE_PARAMETRIZATION = True
|
|||
|
||||
# LoRA Linear for replacement
|
||||
# Pros: simple, just needs to reuse the replace_linear and copy weights
|
||||
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
|
||||
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
|
||||
class Linear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user