actually make deepspeed work with LoRAs

This commit is contained in:
mrq 2024-06-17 13:55:37 -05:00
parent 1d159b1476
commit 7047fcc6e2
5 changed files with 49 additions and 38 deletions

View File

@ -12,6 +12,8 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models from ..models import get_models
from ..utils import wrapper as ml from ..utils import wrapper as ml
from ..models.lora import apply_lora
import torch import torch
import re import re
@ -30,6 +32,8 @@ def load_engines(training=True):
engines = dict() engines = dict()
for name, model in models.items(): for name, model in models.items():
hyper_config = model.config
optimizer = None optimizer = None
lr_scheduler = None lr_scheduler = None
@ -51,6 +55,9 @@ def load_engines(training=True):
if cfg.optimizations.replace and cfg.optimizations.embedding: if cfg.optimizations.replace and cfg.optimizations.embedding:
model.model = ml.replace_embedding( model.model ) model.model = ml.replace_embedding( model.model )
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
optimizer_class = None optimizer_class = None
scheduler_class = None scheduler_class = None
@ -153,8 +160,6 @@ def load_engines(training=True):
model.load_state_dict(state, strict=cfg.trainer.strict_loading) model.load_state_dict(state, strict=cfg.trainer.strict_loading)
hyper_config = model.config
# wrap if DDP is requested # wrap if DDP is requested
if ddp: if ddp:
model = ddp_model(model) model = ddp_model(model)
@ -178,9 +183,6 @@ def load_engines(training=True):
engines = Engines(engines) engines = Engines(engines)
engines.setup() engines.setup()
for name, engine in engines.items():
engine.load_loras()
if not cfg.trainer.load_state_dict: if not cfg.trainer.load_state_dict:
engines.load_checkpoint() engines.load_checkpoint()

View File

@ -29,7 +29,7 @@ def default_feeder(engine, batch):
from ..config import cfg from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging import logging
import time import time
@ -190,17 +190,6 @@ class Engine():
if 'lora' in state: if 'lora' in state:
lora_load_state_dict( self.module, state['lora'] ) lora_load_state_dict( self.module, state['lora'] )
def load_loras( self ):
# apply lora weights
for lora in cfg.loras:
self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
if lora_path.exists():
state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
self.module = lora_load_state_dict( self.module, state_dict )
def eval(self): def eval(self):
return self.module.eval() return self.module.eval()
@ -334,11 +323,11 @@ class Engines(dict[str, Engine]):
for name, engine in self.items(): for name, engine in self.items():
module = engine.module.state_dict() module = engine.module.state_dict()
lora = None lora = None
save_dir = cfg.ckpt_dir / name / "fp32.pth" save_path = cfg.ckpt_dir / name / "fp32.pth"
if cfg.lora is not None: if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True ) lora, module = lora_get_state_dict( module, split = True )
save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth" save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
state_dict = { state_dict = {
'module': module, 'module': module,
@ -352,9 +341,9 @@ class Engines(dict[str, Engine]):
"userdata": userdata "userdata": userdata
} }
if callback: if callback:
state_dict = callback( state_dict, engine.hyper_config ) state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch.save(state_dict, save_dir) torch.save(state_dict, save_path)
print(f"Exported {name} to {outpath}") print(f"Exported {name} to {outpath}")
def save_checkpoint(self, tag=None): def save_checkpoint(self, tag=None):

View File

@ -68,7 +68,12 @@ class Engine(DeepSpeedEngine):
# freeze non-LoRA params if requested # freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
return freeze_non_lora_weights( self.module ) for name, param in self.module.named_parameters():
should = 'lora_' in name
param.requires_grad_(should)
if not should:
self._frozen_params.add(param)
return
for name, param in self.module.named_parameters(): for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
@ -128,16 +133,6 @@ class Engine(DeepSpeedEngine):
return super().save_checkpoint( save_dir, **kwargs ) return super().save_checkpoint( save_dir, **kwargs )
def load_loras( self ):
# apply lora weights
for lora in cfg.loras:
self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
if lora_path.exists():
state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
self.module = lora_load_state_dict( self.module, state_dict )
def traverse(self, *args, **kwargs): def traverse(self, *args, **kwargs):
with ml.autocast(): with ml.autocast():
self.forward(*args, **kwargs) self.forward(*args, **kwargs)

View File

@ -6,9 +6,10 @@ import torch.nn
from .data import get_phone_symmap from .data import get_phone_symmap
from .engines import load_engines from .engines import load_engines
from .config import cfg from .config import cfg
from .models.lora import lora_get_state_dict
# stitches embeddings into one embedding + classifier => lm_head # stitches embeddings into one embedding & classifier => lm_head
def convert_to_hf( state_dict, config = None ): def convert_to_hf( state_dict, config = None, save_path = None ):
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1 n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
token_dim = 1024 token_dim = 1024
embedding = torch.nn.Embedding(n_tokens, token_dim) embedding = torch.nn.Embedding(n_tokens, token_dim)
@ -53,22 +54,46 @@ def convert_to_hf( state_dict, config = None ):
state_dict['module']['lm_head.weight'] = out_proj state_dict['module']['lm_head.weight'] = out_proj
del state_dict['module']['classifier.bias'] del state_dict['module']['classifier.bias']
torch.save(state_dict, "./data/export_test.pth") return state_dict
raise Exception("!") def extract_lora( state_dict, config = None, save_path = None ):
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
lora, module = lora_get_state_dict( state_dict["module"], split = True )
state_dict["module"] = module
state_dict["lora"] = lora
# should raise an exception since there's nothing to extract, or at least a warning
if not lora:
return state_dict
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / "lora.pth"
torch.save( { "module": lora }, save_path )
return state_dict return state_dict
def main(): def main():
parser = argparse.ArgumentParser("Save trained model to path.") parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true') parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if args.module_only: if args.module_only:
cfg.trainer.load_module_only = True cfg.trainer.load_module_only = True
callback = convert_to_hf if args.hf else None callback = None
if args.hf:
callback = convert_to_hf
elif args.lora:
callback = extract_lora
if args.hf and args.lora:
raise Exception("Requesting more than one callback")
engines = load_engines() engines = load_engines()
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback) engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)

View File

@ -14,7 +14,7 @@ USE_PARAMETRIZATION = True
# LoRA Linear for replacement # LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights # Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc) # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class Linear(nn.Linear): class Linear(nn.Linear):
def __init__( def __init__(
self, self,