mechanism to store the model config inside the weights and load it back, plus some other changes to allow LoRA training on the RetNet (gradient checkpointing will gripe about inputs not having requires_grad, and nothing seems to remedy it)

This commit is contained in:
mrq 2024-07-16 18:23:13 -05:00
parent 3acc54df22
commit fe0f235335
6 changed files with 80 additions and 27 deletions
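In short, the export path now writes the model's config dict into the checkpoint alongside the weights, and load_engines() re-instantiates the model from that dict when it finds one. A minimal, self-contained sketch of that round trip follows; the ModelConfig dataclass, build_model() helper, and field names are made-up stand-ins for illustration, not the project's actual classes.

# Sketch of the "config stored inside the weights" round trip (stand-in names).
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn

@dataclass
class ModelConfig:
    dim: int = 1024
    layers: int = 12

def build_model( config: ModelConfig ) -> nn.Module:
    return nn.Sequential(*[nn.Linear(config.dim, config.dim) for _ in range(config.layers)])

# export: stash the config dict alongside the weights
model = build_model( ModelConfig() )
torch.save( { "module": model.state_dict(), "config": asdict( ModelConfig() ) }, "fp32.pth" )

# load: if the checkpoint carries a config, re-instantiate the model from it first
state = torch.load( "fp32.pth", map_location="cpu" )
if "config" in state:
    model = build_model( ModelConfig( **state["config"] ) )
model.load_state_dict( state["module"] )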


@@ -338,6 +338,9 @@ class Model:
         if self.arch_type == "llama":
             include = ["self_attn", "mlp"] # target only the attention + mlp
             exclude = ["self_attn.k_proj"] # common literature says to ignore it
+        if self.arch_type == "retnet":
+            include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
+            exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
         return dict(include=include, exclude=exclude)
@@ -585,6 +588,7 @@ class Trainer:
     load_disabled_engines: bool = False
     weight_dtype: str = "float16"
     amp: bool = False
     ddp: bool = False


@@ -10,7 +10,7 @@ elif cfg.trainer.backend == "local":
 from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
-from ..models import get_models
+from ..models import get_models, get_model
 from ..utils import wrapper as ml
 from ..models.lora import apply_lora
@@ -32,23 +32,49 @@ def load_engines(training=True):
     engines = dict()
     for name, model in models.items():
+        state = None
+        stats = None
+        lora = None
+        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+        loads_state_dict = cfg.trainer.load_state_dict or inferencing
+        checkpoint_path = cfg.ckpt_dir / name / "latest"
+        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+        load_path = cfg.ckpt_dir / name / "fp32.pth"
+        # actually use the lora-specific checkpoint if available
+        if cfg.lora is not None:
+            lora = cfg.lora
+            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
+            print("Checkpoint missing, but weights found.")
+            loads_state_dict = True
+        # load state early
+        if loads_state_dict:
+            state = torch.load(load_path, map_location=torch.device(cfg.device))
+            # check if config is defined in state, and re-initialize the model
+            if "config" in state:
+                print("Model config definition in weights, re-loading...")
+                config_state = state["config"]
+                model = get_model( config=cfg.model.__class__( *config_state ), training=training )
         hyper_config = model.config
         optimizer = None
         lr_scheduler = None
-        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
         backend = cfg.inference.backend if inferencing else cfg.trainer.backend
         dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
         amp = cfg.inference.amp if inferencing else cfg.trainer.amp
-        loads_state_dict = cfg.trainer.load_state_dict or inferencing
         ddp = cfg.trainer.ddp
         engine_class = LocalEngine if backend == "local" or inferencing else Engine
-        if inferencing:
-            model.config.training = False
+        # apply model replacers
         if cfg.optimizations.replace and cfg.optimizations.linear:
            model.model = ml.replace_linear( model.model )
@@ -58,6 +84,9 @@ def load_engines(training=True):
         for lora in cfg.loras:
             model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
+        if inferencing:
+            model.config.training = False
         if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
             optimizer_class = None
             scheduler_class = None
@@ -116,22 +145,8 @@ def load_engines(training=True):
             optimizer = None
             lr_scheduler = None
-        checkpoint_path = cfg.ckpt_dir / name / "latest"
-        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-        load_path = cfg.ckpt_dir / name / "fp32.pth"
-        # actually use the lora-specific checkpoint if available
-        if cfg.lora is not None:
-            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
-        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-            print("Checkpoint missing, but weights found.")
-            loads_state_dict = True
-        stats = None
+        # load state dict if requested / required
         if loads_state_dict:
-            state = torch.load(load_path, map_location=torch.device(cfg.device))
             # state dict is not just the module, extract the extra trainer details
             if "stats" in state:
                 stats = state["stats"]


@@ -320,11 +320,21 @@ class Engines(dict[str, Engine]):
         for engine in self.values():
             engine.dispatch_attribute(*args, **kwargs)
-    def export(self, userdata={}, callback=None):
+    def export(self, userdata={}, callback=None, dtype=None):
+        if dtype is None:
+            dtype = cfg.trainer.dtype
         for name, engine in self.items():
             module = engine.module.state_dict()
             lora = None
             save_path = cfg.ckpt_dir / name / "fp32.pth"
+            config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+            if not isinstance(config, dict):
+                config = config.__dict__
+            # safety
+            for k, v in module.items():
+                module[k] = v.to(dtype)
             if cfg.lora is not None:
                 lora, module = lora_get_state_dict( module, split = True )
@@ -339,7 +349,8 @@ class Engines(dict[str, Engine]):
                     "global_samples": engine.global_samples,
                     "tokens_processed": engine.tokens_processed,
                 },
-                "userdata": userdata
+                "userdata": userdata,
+                "config": config
             }
             if callback:
                 state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )


@@ -56,7 +56,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     return state_dict
-def extract_lora( state_dict, config = None, save_path = None ):
+def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
+    if dtype is None:
+        dtype = cfg.inference.dtype
     lora = state_dict["lora"] if "lora" in state_dict else None
     # should always be included, but just in case
     if lora is None and "module" in state_dict:
@@ -71,7 +74,10 @@ def extract_lora( state_dict, config = None, save_path = None ):
     # save lora specifically
     # should probably export other attributes, similar to what SD LoRAs do
     save_path = save_path.parent / "lora.pth"
-    torch.save( { "module": lora }, save_path )
+    torch.save( {
+        "module": lora,
+        "config": cfg.lora.__dict__ if cfg.lora is not None else None,
+    }, save_path )
     return state_dict
@@ -81,6 +87,7 @@ def main():
     parser.add_argument("--module-only", action='store_true')
     parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
     parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+    parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
     args, unknown = parser.parse_known_args()
     if args.module_only:
@@ -95,7 +102,10 @@ def main():
     if args.hf and args.lora:
         raise Exception("Requesting more than one callback")
-    engines = load_engines()
+    if args.dtype != "auto":
+        cfg.trainer.weight_dtype = args.dtype
+    engines = load_engines(training=False)
     engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)

 if __name__ == "__main__":


@@ -581,6 +581,18 @@ class Base(nn.Module):
             ))
             self.model = RetNetDecoder(RetNetConfig(**kwargs))
+            # do some funny stuff for LoRA training
+            """
+            if self.gradient_checkpointing:
+                def make_inputs_require_grads(module, input, output):
+                    for i, t in enumerate(input):
+                        if not isinstance(t, torch.Tensor):
+                            continue
+                        t.requires_grad_(True)
+                self.model.register_forward_hook(make_inputs_require_grads)
+            """
         elif self.arch_type == "retnet-hf":
             kwargs = dict(
                 vocab_size=n_resp_tokens,
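For context on the gradient-checkpointing gripe from the commit message: the usual workaround when training LoRAs over a frozen base is to force the hidden states entering the checkpointed blocks to require grad, typically by hooking the input embeddings (the same idea as Hugging Face's enable_input_require_grads()). The sketch below shows that pattern in isolation with stand-in modules; whether torchscale's RetNetDecoder exposes a workable hook point like this is exactly what the commented-out block above was probing, and per the commit message it did not pan out here.

# Stand-alone illustration of the "make inputs require grad" workaround.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

embed = nn.Embedding(256, 64)        # stand-in for the frozen token embedding
block = nn.Linear(64, 64)            # stand-in for a checkpointed decoder block
embed.weight.requires_grad_(False)   # simulate the frozen base model

# force the embedding output to carry requires_grad so reentrant checkpointing
# has something to backprop through
def make_outputs_require_grad(module, input, output):
    output.requires_grad_(True)

embed.register_forward_hook(make_outputs_require_grad)

tokens = torch.randint(0, 256, (1, 8))
hidden = embed(tokens)               # requires_grad is True despite the frozen embedding
out = checkpoint(block, hidden, use_reentrant=True)
out.sum().backward()                 # block.weight.grad is populated, no requires_grad warning
# passing use_reentrant=False instead also sidesteps the warning entirely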


@@ -148,6 +148,7 @@ class ParameterizedLoRA(nn.Module):
 def passes_policy( policy, name ):
     if policy is None:
         return True
     if "exclude" in policy:
         for term in policy["exclude"]:
             if term in name:
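The include/exclude lists added for the RetNet in the lora_policy hunk above are just substring filters over module names, applied by passes_policy(). Below is a small illustrative sketch of how such a policy plays out; the example module names and the fallthrough behaviour are assumptions for demonstration, not a copy of the project's function.

# Illustrative re-implementation: exclude terms veto first, then include terms admit.
def passes_policy( policy, name ):
    if policy is None:
        return True
    for term in policy.get("exclude", []):
        if term in name:
            return False
    for term in policy.get("include", []):
        if term in name:
            return True
    return False

policy = dict( include=["layers."], exclude=["retention.k_proj"] )
print( passes_policy( policy, "layers.0.retention.q_proj" ) )  # True  -> gets a LoRA
print( passes_policy( policy, "layers.0.retention.k_proj" ) )  # False -> skipped, per the exclude rule
print( passes_policy( policy, "text_emb.weight" ) )            # False -> not in the targeted core layers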