naive, rudimentary DeepSpeed support (just live with the LoRA weights living with the original weights, they can be split later)
commit 726a4b613f
parent bd0bc10ec0
@@ -8,7 +8,7 @@ if cfg.trainer.backend == "deepspeed":
 elif cfg.trainer.backend == "local":
     from .base import Engine
 
-from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models
 from ..utils import wrapper as ml
@@ -40,7 +40,7 @@ def load_engines(training=True):
     loads_state_dict = cfg.trainer.load_state_dict or inferencing
     ddp = cfg.trainer.ddp
 
-    engine_class = _Engine if backend == "local" or inferencing else Engine
+    engine_class = LocalEngine if backend == "local" or inferencing else Engine
 
     if inferencing:
         model.config.training = False
@@ -101,9 +101,9 @@ def load_engines(training=True):
             warmup_steps = cfg.hyperparameters.warmup_steps
         )
 
-
+    """
     # set up our LR scheduler here
-
+    """
 
     if inferencing:
         optimizer = None
@@ -113,7 +113,7 @@ def load_engines(training=True):
     load_path = cfg.ckpt_dir / name / "fp32.pth"
 
     if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
-        print("DeepSpeed checkpoint missing, but weights found.")
+        print("Checkpoint missing, but weights found.")
         loads_state_dict = True
 
     stats = None
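Note: a minimal sketch (not this repo's exact code) of what the loads_state_dict fallback amounts to once the flag above is flipped: the exported fp32.pth is loaded directly instead of resuming a DeepSpeed checkpoint. The "module" key unwrapping is an assumption about how the export might be nested.

import torch

def load_raw_weights(model, load_path, device="cpu"):
    # sketch: load exported fp32 weights directly instead of a DeepSpeed checkpoint dir
    state_dict = torch.load(load_path, map_location=torch.device(device))
    # assumption: some exports nest the tensors under a "module" key
    if isinstance(state_dict, dict) and "module" in state_dict:
        state_dict = state_dict["module"]
    model.load_state_dict(state_dict, strict=False)
    return model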
@@ -161,7 +161,7 @@ def load_engines(training=True):
 
     # deepspeed inferencing
     elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
-        engine_class = _Engine
+        engine_class = LocalEngine
         model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
 
     # use base engine if requested
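For reference, a standalone usage sketch of the deepspeed.init_inference call used above; build_model() is a hypothetical stand-in for any torch.nn.Module. init_inference returns an inference engine whose .module attribute is the kernel-injected model, which is why the diff unwraps it.

import torch
import deepspeed

model = build_model()  # hypothetical: any torch.nn.Module
model = deepspeed.init_inference(
    model=model,
    mp_size=1,                        # no model parallelism
    replace_with_kernel_inject=True,  # swap supported layers for fused inference kernels
    dtype=torch.float16,
).module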
@@ -79,7 +79,7 @@ class Engine():
             raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
         # freeze non-LoRA params if requested
-        if not self.hyper_config.frozen_params and not freeze_all:
+        if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
             return freeze_non_lora_weights( self.module )
 
         for name, param in self.module.named_parameters():
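freeze_non_lora_weights itself is not shown in this diff; a minimal sketch of what it presumably does, assuming LoRA parameters are identifiable by a "lora_" substring in their names (that naming is an assumption here, not taken from the repo). The same cfg.lora guard is mirrored in the DeepSpeed engine below.

def freeze_non_lora_weights(module, key="lora_"):
    # assumed behavior: leave parameters whose names contain the LoRA marker trainable,
    # freeze everything else
    for name, param in module.named_parameters():
        param.requires_grad_(key in name)
    return module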
@@ -66,6 +66,10 @@ class Engine(DeepSpeedEngine):
         if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
             raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
+        # freeze non-LoRA params if requested
+        if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+            return freeze_non_lora_weights( self.module )
+
         for name, param in self.module.named_parameters():
             if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
                 param.requires_grad_(False)
@@ -108,8 +112,31 @@ class Engine(DeepSpeedEngine):
         except Exception as e:
             print(str(e))
 
+    # we'll just have to live with the LoRA weights living within our main weights
+    # they're easy to extract anyways
+    def load_checkpoint(self, load_dir, **kwargs ):
+        # override to load the lora instead
+        if cfg.lora is not None:
+            load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+        return super().load_checkpoint( load_dir, **kwargs )
+
+    def save_checkpoint(self, save_dir, **kwargs ):
+        # override to save the lora instead
+        if cfg.lora is not None:
+            save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+        return super().save_checkpoint( save_dir, **kwargs )
+
     def load_loras( self ):
-        ...
+        # apply lora weights
+        for lora in cfg.loras:
+            self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
+
+            lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+            if lora_path.exists():
+                state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+                self.module = lora_load_state_dict( self.module, state_dict )
 
     def traverse(self, *args, **kwargs):
         with ml.autocast():
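On "they're easy to extract anyways": a hedged sketch of splitting the LoRA tensors back out of a combined fp32.pth after the fact, again assuming a "lora_" substring marks the LoRA parameters (hypothetical naming, not confirmed by this diff).

import torch

def extract_lora_weights(combined_path, lora_path, key="lora_"):
    # keep only the LoRA tensors from the full checkpoint, keyed by parameter name
    state_dict = torch.load(combined_path, map_location="cpu")
    lora_only = { k: v for k, v in state_dict.items() if key in k }
    torch.save(lora_only, lora_path)
    return lora_only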
@@ -28,7 +28,7 @@ from .distributed import (
     local_leader_only,
 )
 
-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
 
 from .utils import to_device, do_gc, truncate_json
 from ..utils import wrapper as ml