naive, rudimentary DeepSpeed LoRA support (just live with the LoRA weights living alongside the original weights; they can be split out later)

mrq 2024-06-17 13:17:24 -05:00
parent bd0bc10ec0
commit 726a4b613f
4 changed files with 37 additions and 10 deletions
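
The "split out later" step from the commit message is simple if LoRA parameters are identifiable by name. A minimal sketch, assuming a "lora_" substring in the parameter names (that naming hint is an assumption, not taken from this commit):

import torch

# hypothetical post-hoc splitter: pull the LoRA tensors out of a combined fp32.pth
def split_lora_weights(combined_path, lora_path, key_hint="lora_"):
    # the engine saves base weights and LoRA weights into a single state dict
    state_dict = torch.load(combined_path, map_location="cpu")
    # keep only tensors whose parameter name marks them as LoRA weights
    lora_only = { k: v for k, v in state_dict.items() if key_hint in k }
    torch.save(lora_only, lora_path)
    return lora_only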

View File

@@ -8,7 +8,7 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
    from .base import Engine
-from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models
from ..utils import wrapper as ml
@@ -40,7 +40,7 @@ def load_engines(training=True):
        loads_state_dict = cfg.trainer.load_state_dict or inferencing
        ddp = cfg.trainer.ddp
-       engine_class = _Engine if backend == "local" or inferencing else Engine
+       engine_class = LocalEngine if backend == "local" or inferencing else Engine
        if inferencing:
            model.config.training = False
@@ -101,9 +101,9 @@ def load_engines(training=True):
            warmup_steps = cfg.hyperparameters.warmup_steps
        )
        """
        # set up our LR scheduler here
        """
        if inferencing:
            optimizer = None
@@ -113,7 +113,7 @@ def load_engines(training=True):
        load_path = cfg.ckpt_dir / name / "fp32.pth"
        if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
-           print("DeepSpeed checkpoint missing, but weights found.")
+           print("Checkpoint missing, but weights found.")
            loads_state_dict = True
        stats = None
@@ -161,7 +161,7 @@ def load_engines(training=True):
        # deepspeed inferencing
        elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
-           engine_class = _Engine
+           engine_class = LocalEngine
            model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
        # use base engine if requested
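
The reworded print above belongs to a fallback: when no DeepSpeed checkpoint exists (no "latest" tag file in the checkpoint directory) but consolidated weights do, the loader switches to a plain state-dict load. A minimal sketch of that decision as a hypothetical helper, assuming the on-disk layout the diff implies:

from pathlib import Path

# sketch of the fallback, assuming this layout (names taken from the diff):
#   {ckpt_dir}/{name}/latest    -> DeepSpeed's tag file for a full checkpoint
#   {ckpt_dir}/{name}/fp32.pth  -> plain consolidated weights
def should_load_state_dict(ckpt_dir: Path, name: str, loads_state_dict: bool) -> bool:
    has_deepspeed_ckpt = (ckpt_dir / name / "latest").exists()
    has_raw_weights = (ckpt_dir / name / "fp32.pth").exists()
    if not loads_state_dict and not has_deepspeed_ckpt and has_raw_weights:
        print("Checkpoint missing, but weights found.")
        return True
    return loads_state_dict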

View File

@@ -79,7 +79,7 @@ class Engine():
            raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
        # freeze non-LoRA params if requested
-       if not self.hyper_config.frozen_params and not freeze_all:
+       if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
            return freeze_non_lora_weights( self.module )
        for name, param in self.module.named_parameters():
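
freeze_non_lora_weights itself is not part of this diff. A plausible sketch of such a helper, again assuming LoRA parameters are identifiable by a name substring (the "lora_" hint is an assumption):

import torch.nn as nn

# hypothetical helper: leave only LoRA parameters trainable
def freeze_non_lora_weights(module: nn.Module, key_hint: str = "lora_") -> nn.Module:
    for name, param in module.named_parameters():
        # a parameter stays trainable only if its name marks it as a LoRA weight
        param.requires_grad_(key_hint in name)
    return module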

View File

@@ -66,6 +66,10 @@ class Engine(DeepSpeedEngine):
        if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
            raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
+       # freeze non-LoRA params if requested
+       if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+           return freeze_non_lora_weights( self.module )
        for name, param in self.module.named_parameters():
            if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
                param.requires_grad_(False)
@@ -108,8 +112,31 @@ class Engine(DeepSpeedEngine):
        except Exception as e:
            print(str(e))
-   def load_loras(self):
-       ...
+   # we'll just have to live with the LoRA weights living within our main weights
+   # they're easy to extract anyways
+   def load_checkpoint(self, load_dir, **kwargs ):
+       # override to load the lora instead
+       if cfg.lora is not None:
+           load_dir = cfg.ckpt_dir / cfg.lora.full_name
+       return super().load_checkpoint( load_dir, **kwargs )
+   def save_checkpoint(self, save_dir, **kwargs ):
+       # override to save the lora instead
+       if cfg.lora is not None:
+           save_dir = cfg.ckpt_dir / cfg.lora.full_name
+       return super().save_checkpoint( save_dir, **kwargs )
+   def load_loras( self ):
+       # apply lora weights
+       for lora in cfg.loras:
+           self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
+           lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+           if lora_path.exists():
+               state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+               self.module = lora_load_state_dict( self.module, state_dict )
    def traverse(self, *args, **kwargs):
        with ml.autocast():
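
apply_lora and lora_load_state_dict are the repo's own helpers and are not shown in this diff. The load side plausibly reduces to a non-strict state-dict load, since the saved dict holds only the LoRA tensors; a minimal sketch (the strict=False behavior is an assumption):

import torch.nn as nn

# hypothetical sketch: apply only the LoRA tensors on top of base weights
# that are already loaded; strict=False because the base-model keys are absent
def lora_load_state_dict(module: nn.Module, state_dict: dict) -> nn.Module:
    module.load_state_dict(state_dict, strict=False)
    return module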

View File

@@ -28,7 +28,7 @@ from .distributed import (
    local_leader_only,
)
-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml