naive, rudimentary DeepSpeed LoRA support (just live with the LoRA weights living alongside the original weights; they can be split out later)
This commit is contained in:
parent bd0bc10ec0
commit 726a4b613f
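The commit message notes that LoRA weights left inside the combined checkpoint can be split out later. A minimal sketch of what that post-hoc split could look like, assuming LoRA parameters are identifiable by a "lora" substring in their names (the `split_lora_weights` helper and the naming convention are assumptions, not part of this commit):

```python
# Hypothetical post-hoc split of a combined fp32.pth into base and LoRA state dicts.
# Assumes LoRA parameters carry "lora" in their names; adjust to the real naming scheme.
import torch

def split_lora_weights( combined_path, base_out, lora_out ):
	state_dict = torch.load( combined_path, map_location="cpu" )
	lora_sd = { k: v for k, v in state_dict.items() if "lora" in k }
	base_sd = { k: v for k, v in state_dict.items() if "lora" not in k }
	torch.save( base_sd, base_out )
	torch.save( lora_sd, lora_out )
	return base_sd, lora_sd
```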
@@ -8,7 +8,7 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
	from .base import Engine

-from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine

from ..models import get_models
from ..utils import wrapper as ml
@@ -40,7 +40,7 @@ def load_engines(training=True):
	loads_state_dict = cfg.trainer.load_state_dict or inferencing
	ddp = cfg.trainer.ddp

-	engine_class = _Engine if backend == "local" or inferencing else Engine
+	engine_class = LocalEngine if backend == "local" or inferencing else Engine

	if inferencing:
		model.config.training = False
@@ -101,9 +101,9 @@ def load_engines(training=True):
				warmup_steps = cfg.hyperparameters.warmup_steps
			)


		"""
		# set up our LR scheduler here
		"""

		if inferencing:
			optimizer = None
@@ -113,7 +113,7 @@ def load_engines(training=True):
		load_path = cfg.ckpt_dir / name / "fp32.pth"

		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
-			print("DeepSpeed checkpoint missing, but weights found.")
+			print("Checkpoint missing, but weights found.")
			loads_state_dict = True

		stats = None
@@ -161,7 +161,7 @@ def load_engines(training=True):

		# deepspeed inferencing
		elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
-			engine_class = _Engine
+			engine_class = LocalEngine
			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

		# use base engine if requested
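For context on the fallback in the hunk at line 113 above: DeepSpeed checkpoints keep a `latest` tag file in the checkpoint directory, so its absence next to an existing `fp32.pth` means only raw weights are available. A rough standalone sketch of that decision (the function name and explicit arguments are illustrative only, not part of this commit):

```python
# Sketch of the fallback decision: prefer a proper engine checkpoint,
# but fall back to loading a plain state dict when only fp32.pth exists.
from pathlib import Path

def should_load_state_dict( ckpt_dir: Path, name: str, loads_state_dict: bool ) -> bool:
	load_path = ckpt_dir / name / "fp32.pth"
	has_engine_ckpt = ( ckpt_dir / name / "latest" ).exists() # DeepSpeed's "latest" tag file
	if not loads_state_dict and not has_engine_ckpt and load_path.exists():
		print("Checkpoint missing, but weights found.")
		return True
	return loads_state_dict
```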
@@ -79,7 +79,7 @@ class Engine():
			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

		# freeze non-LoRA params if requested
-		if not self.hyper_config.frozen_params and not freeze_all:
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
			return freeze_non_lora_weights( self.module )

		for name, param in self.module.named_parameters():
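The guard above only defers to `freeze_non_lora_weights` when a LoRA is configured and no explicit frozen-parameter list is given. That helper is not shown in this diff; a plausible sketch, assuming LoRA parameters are identified by a name substring (an assumption, not confirmed here):

```python
# Plausible sketch of freeze_non_lora_weights (referenced above, not shown in this diff):
# freeze every parameter except those whose names mark them as LoRA weights.
def freeze_non_lora_weights( module, key = "lora" ):
	for name, param in module.named_parameters():
		param.requires_grad_( key in name )
	return module
```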
@@ -66,6 +66,10 @@ class Engine(DeepSpeedEngine):
		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

+		# freeze non-LoRA params if requested
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+			return freeze_non_lora_weights( self.module )
+
		for name, param in self.module.named_parameters():
			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
				param.requires_grad_(False)
@@ -108,8 +112,31 @@ class Engine(DeepSpeedEngine):
		except Exception as e:
			print(str(e))

-	def load_loras(self):
-		...
+	# we'll just have to live with the LoRA weights living within our main weights
+	# they're easy to extract anyways
+	def load_checkpoint(self, load_dir, **kwargs ):
+		# override to load the lora instead
+		if cfg.lora is not None:
+			load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+		return super().load_checkpoint( load_dir, **kwargs )
+
+	def save_checkpoint(self, save_dir, **kwargs ):
+		# override to save the lora instead
+		if cfg.lora is not None:
+			save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+		return super().save_checkpoint( save_dir, **kwargs )
+
+	def load_loras( self ):
+		# apply lora weights
+		for lora in cfg.loras:
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
+
+			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+			if lora_path.exists():
+				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+				self.module = lora_load_state_dict( self.module, state_dict )

	def traverse(self, *args, **kwargs):
		with ml.autocast():
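With the overrides above, the DeepSpeed engine reads and writes its checkpoints under the LoRA's own directory whenever `cfg.lora` is set, and `load_loras` applies the LoRA modules before pulling their weights from `fp32.pth`. The `lora_load_state_dict` helper it calls is not shown in this diff; a plausible sketch, assuming it merges only the LoRA entries non-strictly:

```python
# Plausible sketch of lora_load_state_dict (referenced above, not shown in this diff):
# merge only the LoRA entries from a saved state dict, leaving base weights untouched.
def lora_load_state_dict( module, state_dict, key = "lora" ):
	lora_entries = { k: v for k, v in state_dict.items() if key in k }
	module.load_state_dict( lora_entries, strict = False )
	return module
```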
@@ -28,7 +28,7 @@ from .distributed import (
	local_leader_only,
)

-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines

from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml