naive, rudimentary DeepSpeed support (just live with the LoRA weights living with the original weights, they can be split later)

mrq 2024-06-17 13:17:24 -05:00
parent bd0bc10ec0
commit 726a4b613f
4 changed files with 37 additions and 10 deletions
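The "split later" in the commit message refers to pulling the LoRA tensors back out of the combined checkpoint afterwards. A minimal sketch of how that split could look, assuming LoRA parameters are identifiable by a "lora_" substring in their state-dict keys (a hypothetical helper, not part of this commit):

```python
import torch

def split_lora_weights( combined_path, lora_path, base_path ):
	# partition the combined checkpoint by key name: LoRA tensors vs. base weights
	state_dict = torch.load( combined_path, map_location="cpu" )
	lora = { k: v for k, v in state_dict.items() if "lora_" in k }
	base = { k: v for k, v in state_dict.items() if "lora_" not in k }
	torch.save( lora, lora_path )
	torch.save( base, base_path )
```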

View File

@@ -8,7 +8,7 @@ if cfg.trainer.backend == "deepspeed":
 elif cfg.trainer.backend == "local":
 	from .base import Engine
 
-from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models
 from ..utils import wrapper as ml
@@ -40,7 +40,7 @@ def load_engines(training=True):
 	loads_state_dict = cfg.trainer.load_state_dict or inferencing
 	ddp = cfg.trainer.ddp
-	engine_class = _Engine if backend == "local" or inferencing else Engine
+	engine_class = LocalEngine if backend == "local" or inferencing else Engine
 
 	if inferencing:
 		model.config.training = False
@@ -101,9 +101,9 @@ def load_engines(training=True):
 			warmup_steps = cfg.hyperparameters.warmup_steps
 		)
-	"""
 	# set up our LR scheduler here
+	"""
 	if inferencing:
 		optimizer = None
@@ -113,7 +113,7 @@ def load_engines(training=True):
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 
 		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
-			print("DeepSpeed checkpoint missing, but weights found.")
+			print("Checkpoint missing, but weights found.")
 			loads_state_dict = True
 
 		stats = None
@@ -161,7 +161,7 @@ def load_engines(training=True):
 	# deepspeed inferencing
 	elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
-		engine_class = _Engine
+		engine_class = LocalEngine
 		model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
 	# use base engine if requested

View File

@@ -79,7 +79,7 @@ class Engine():
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
 		# freeze non-LoRA params if requested
-		if not self.hyper_config.frozen_params and not freeze_all:
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
 			return freeze_non_lora_weights( self.module )
 
 		for name, param in self.module.named_parameters():
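freeze_non_lora_weights itself isn't shown in this diff (both engine classes call it); a minimal sketch of what it plausibly does, assuming injected LoRA parameters carry "lora_" in their names, as in common LoRA implementations:

```python
import torch.nn as nn

def freeze_non_lora_weights( module: nn.Module ) -> nn.Module:
	# leave only the LoRA parameters trainable; freeze everything else
	for name, param in module.named_parameters():
		param.requires_grad_( "lora_" in name )
	return module
```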

View File

@@ -66,6 +66,10 @@ class Engine(DeepSpeedEngine):
 		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
+		# freeze non-LoRA params if requested
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+			return freeze_non_lora_weights( self.module )
+
 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
 				param.requires_grad_(False)
@@ -108,8 +112,31 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			print(str(e))
 
+	# we'll just have to live with the LoRA weights living within our main weights
+	# they're easy to extract anyways
+	def load_checkpoint(self, load_dir, **kwargs ):
+		# override to load the lora instead
+		if cfg.lora is not None:
+			load_dir = cfg.ckpt_dir / cfg.lora.full_name
+		return super().load_checkpoint( load_dir, **kwargs )
+
+	def save_checkpoint(self, save_dir, **kwargs ):
+		# override to save the lora instead
+		if cfg.lora is not None:
+			save_dir = cfg.ckpt_dir / cfg.lora.full_name
+		return super().save_checkpoint( save_dir, **kwargs )
+
 	def load_loras( self ):
-		...
+		# apply lora weights
+		for lora in cfg.loras:
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
+
+			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+			if lora_path.exists():
+				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+				self.module = lora_load_state_dict( self.module, state_dict )
 
 	def traverse(self, *args, **kwargs):
 		with ml.autocast():
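apply_lora and lora_load_state_dict come from the project's lora module and aren't shown in this diff. Since the checkpoint at fp32.pth holds only LoRA tensors here, the loader presumably merges them into the module non-strictly; a minimal sketch under that assumption:

```python
import torch.nn as nn

def lora_load_state_dict( module: nn.Module, state_dict: dict ) -> nn.Module:
	# the state dict carries only LoRA tensors, so ignore the missing base keys
	module.load_state_dict( state_dict, strict=False )
	return module
```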

View File

@@ -28,7 +28,7 @@ from .distributed import (
 	local_leader_only,
 )
 
-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
 from .utils import to_device, do_gc, truncate_json
 from ..utils import wrapper as ml