actually make deepspeed work with LoRAs

parent 1d159b1476
commit 7047fcc6e2
@@ -12,6 +12,8 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 from ..models import get_models
 from ..utils import wrapper as ml

+from ..models.lora import apply_lora
+
 import torch
 import re

@@ -30,6 +32,8 @@ def load_engines(training=True):
 	engines = dict()

 	for name, model in models.items():
+		hyper_config = model.config
+
 		optimizer = None
 		lr_scheduler = None

@@ -51,6 +55,9 @@ def load_engines(training=True):
 		if cfg.optimizations.replace and cfg.optimizations.embedding:
 			model.model = ml.replace_embedding( model.model )

+		for lora in cfg.loras:
+			model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+
 		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
 			optimizer_class = None
 			scheduler_class = None
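With this hunk the adapters are injected while the model is still a bare torch module, before the backend (local or DeepSpeed) wraps it. A minimal sketch of that ordering, using a stand-in for `apply_lora` — the helper name, shapes, and defaults below are illustrative, not the project's implementation:

```python
import torch
import torch.nn as nn

def toy_apply_lora(model: nn.Module, rank: int = 4) -> nn.Module:
	# stand-in for apply_lora(): register extra trainable tensors on each Linear
	# (the forward pass is left untouched here for brevity)
	for module in model.modules():
		if isinstance(module, nn.Linear):
			module.lora_A = nn.Parameter(torch.zeros(rank, module.in_features))
			module.lora_B = nn.Parameter(torch.zeros(module.out_features, rank))
	return model

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
model = toy_apply_lora(model)  # 1) inject LoRA parameters first
# 2) only now hand the model to the backend (e.g. deepspeed.initialize(...)),
#    so the engine's parameter list already contains the lora_A/lora_B tensors
print(sum("lora_" in name for name, _ in model.named_parameters()))  # 4
```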
@@ -153,8 +160,6 @@ def load_engines(training=True):

 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)

-		hyper_config = model.config
-
 		# wrap if DDP is requested
 		if ddp:
 			model = ddp_model(model)
@@ -178,9 +183,6 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()

-	for name, engine in engines.items():
-		engine.load_loras()
-
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()

@@ -29,7 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
-from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
+from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict

 import logging
 import time
@@ -190,17 +190,6 @@ class Engine():
 		if 'lora' in state:
 			lora_load_state_dict( self.module, state['lora'] )

-
-	def load_loras( self ):
-		# apply lora weights
-		for lora in cfg.loras:
-			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
-
-			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
-			if lora_path.exists():
-				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
-				self.module = lora_load_state_dict( self.module, state_dict )
-
 	def eval(self):
 		return self.module.eval()

@@ -334,11 +323,11 @@ class Engines(dict[str, Engine]):
 		for name, engine in self.items():
 			module = engine.module.state_dict()
 			lora = None
-			save_dir = cfg.ckpt_dir / name / "fp32.pth"
+			save_path = cfg.ckpt_dir / name / "fp32.pth"

 			if cfg.lora is not None:
 				lora, module = lora_get_state_dict( module, split = True )
-				save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+				save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"

 			state_dict = {
 				'module': module,
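The exporter now splits the LoRA tensors out of the module's state dict via `lora_get_state_dict( module, split = True )` and targets the LoRA's own checkpoint path. A rough sketch of that kind of split, assuming (as elsewhere in this commit) that adapter tensors are identified by a `lora_` marker in their key; the real helper may differ:

```python
import torch

def split_lora_state_dict(state_dict: dict) -> tuple[dict, dict]:
	# separate adapter weights from the base module by key
	lora = {k: v for k, v in state_dict.items() if "lora_" in k}
	base = {k: v for k, v in state_dict.items() if "lora_" not in k}
	return lora, base

sd = {"proj.weight": torch.zeros(4, 4), "proj.lora_A": torch.zeros(2, 4)}
lora, base = split_lora_state_dict(sd)
print(sorted(lora), sorted(base))  # ['proj.lora_A'] ['proj.weight']
```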
@@ -352,9 +341,9 @@ class Engines(dict[str, Engine]):
 				"userdata": userdata
 			}
 			if callback:
-				state_dict = callback( state_dict, engine.hyper_config )
+				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

-			torch.save(state_dict, save_dir)
+			torch.save(state_dict, save_path)
 			print(f"Exported {name} to {outpath}")

 	def save_checkpoint(self, tag=None):
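Export callbacks now receive the engine's config and the resolved save path as keyword arguments and must hand the (possibly modified) state dict back to the exporter. A small sketch of that contract with a placeholder callback; the path and keys are illustrative only:

```python
from pathlib import Path

def dummy_callback(state_dict, config = None, save_path = None):
	# a callback may repack or strip entries here, the way convert_to_hf /
	# extract_lora do, then must return the state dict for torch.save()
	state_dict["userdata"]["exported_next_to"] = str(save_path)
	return state_dict

state_dict = {"module": {}, "userdata": {}}
save_path = Path("./data/ckpt/example/fp32.pth")  # hypothetical checkpoint path
state_dict = dummy_callback(state_dict, config = None, save_path = save_path)
print(state_dict["userdata"])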
@@ -68,7 +68,12 @@ class Engine(DeepSpeedEngine):

 		# freeze non-LoRA params if requested
 		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-			return freeze_non_lora_weights( self.module )
+			for name, param in self.module.named_parameters():
+				should = 'lora_' in name
+				param.requires_grad_(should)
+				if not should:
+					self._frozen_params.add(param)
+			return

 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
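The DeepSpeed engine now performs the freeze in place, so the frozen parameters land in its own `_frozen_params` set rather than whatever `freeze_non_lora_weights` returned. A standalone sketch of the same rule (only names containing `lora_` stay trainable), outside of any engine:

```python
import torch
import torch.nn as nn

def freeze_non_lora(module: nn.Module) -> set:
	frozen = set()
	for name, param in module.named_parameters():
		trainable = "lora_" in name
		param.requires_grad_(trainable)
		if not trainable:
			frozen.add(name)
	return frozen

net = nn.Linear(8, 8)
net.lora_A = nn.Parameter(torch.zeros(2, 8))  # toy adapter tensor
print(freeze_non_lora(net))  # {'weight', 'bias'} -- lora_A stays trainable
```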
@@ -128,16 +133,6 @@ class Engine(DeepSpeedEngine):

 		return super().save_checkpoint( save_dir, **kwargs )

-	def load_loras( self ):
-		# apply lora weights
-		for lora in cfg.loras:
-			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
-
-			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
-			if lora_path.exists():
-				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
-				self.module = lora_load_state_dict( self.module, state_dict )
-
 	def traverse(self, *args, **kwargs):
 		with ml.autocast():
 			self.forward(*args, **kwargs)
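Dropping the post-init `load_loras()` path is the core of the fix: an optimizer, and a DeepSpeed engine built around one, captures its parameter list when it is constructed, so adapters bolted on afterwards never receive updates. A plain-torch illustration of that failure mode (names are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters())     # parameter list is captured here
model.lora_A = nn.Parameter(torch.zeros(2, 8))  # adapter injected too late
in_opt = any(p is model.lora_A for group in opt.param_groups for p in group["params"])
print(in_opt)  # False -- this parameter would never be updated
```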
@@ -6,9 +6,10 @@ import torch.nn
 from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
+from .models.lora import lora_get_state_dict

-# stitches embeddings into one embedding + classifier => lm_head
-def convert_to_hf( state_dict, config = None ):
+# stitches embeddings into one embedding & classifier => lm_head
+def convert_to_hf( state_dict, config = None, save_path = None ):
 	n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
 	token_dim = 1024
 	embedding = torch.nn.Embedding(n_tokens, token_dim)
@@ -53,22 +54,46 @@ def convert_to_hf( state_dict, config = None ):
 	state_dict['module']['lm_head.weight'] = out_proj
 	del state_dict['module']['classifier.bias']

-	torch.save(state_dict, "./data/export_test.pth")
+	return state_dict

-	raise Exception("!")
+def extract_lora( state_dict, config = None, save_path = None ):
+	lora = state_dict["lora"] if "lora" in state_dict else None
+	# should always be included, but just in case
+	if lora is None and "module" in state_dict:
+		lora, module = lora_get_state_dict( state_dict["module"], split = True )
+		state_dict["module"] = module
+		state_dict["lora"] = lora
+
+	# should raise an exception since there's nothing to extract, or at least a warning
+	if not lora:
+		return state_dict
+
+	# save lora specifically
+	# should probably export other attributes, similar to what SD LoRAs do
+	save_path = save_path.parent / "lora.pth"
+	torch.save( { "module": lora }, save_path )
+
 	return state_dict


 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
+	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
 	args, unknown = parser.parse_known_args()

 	if args.module_only:
 		cfg.trainer.load_module_only = True

-	callback = convert_to_hf if args.hf else None
+	callback = None
+	if args.hf:
+		callback = convert_to_hf
+	elif args.lora:
+		callback = extract_lora
+
+	if args.hf and args.lora:
+		raise Exception("Requesting more than one callback")

 	engines = load_engines()
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
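For reference, an export run with `--lora` should leave a `lora.pth` next to the usual `fp32.pth`, holding only the adapter tensors under a `module` key (per `extract_lora` above). A hedged sketch of reading it back; the path is illustrative:

```python
from pathlib import Path
import torch

lora_path = Path("./data/ckpt/example/lora.pth")  # sibling of the exported fp32.pth
if lora_path.exists():  # guard keeps the sketch runnable without a checkpoint
	lora_weights = torch.load(lora_path, map_location="cpu")["module"]
	print(f"{len(lora_weights)} LoRA tensors")
```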
@@ -14,7 +14,7 @@ USE_PARAMETRIZATION = True

 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
-# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
+# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
 class Linear(nn.Linear):
 	def __init__(
 		self,
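For context on the "Linear replacement" flavour the comments above weigh up, here is an illustrative-only LoRA Linear; the project's actual `Linear`/parametrization classes are more involved, and the rank/alpha defaults below are arbitrary:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Linear):
	def __init__(self, in_features, out_features, bias = True, rank = 4, alpha = 8):
		super().__init__(in_features, out_features, bias = bias)
		self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
		self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
		nn.init.normal_(self.lora_A, std = 0.02)  # B stays zero: training starts at W
		self.scale = alpha / rank
		self.weight.requires_grad_(False)  # base weight (and bias) stay frozen
		if self.bias is not None:
			self.bias.requires_grad_(False)

	def forward(self, x):
		return super().forward(x) + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scale

layer = LoRALinear(16, 16)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```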