mechanism to store the model config inside the weights and load them, some other things to allow LoRA training on the RetNet (gradient checkpointing will gripe about inputs not having require_grad and nothing seems to remedy it)

This commit is contained in:
mrq 2024-07-16 18:23:13 -05:00
parent 3acc54df22
commit fe0f235335
6 changed files with 80 additions and 27 deletions

View File

@ -338,6 +338,9 @@ class Model:
if self.arch_type == "llama":
include = ["self_attn", "mlp"] # target only the attention + mlp
exclude = ["self_attn.k_proj"] # common literature says to ignore it
if self.arch_type == "retnet":
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
return dict(include=include, exclude=exclude)
@ -585,6 +588,7 @@ class Trainer:
load_disabled_engines: bool = False
weight_dtype: str = "float16"
amp: bool = False
ddp: bool = False

View File

@ -10,7 +10,7 @@ elif cfg.trainer.backend == "local":
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..models.lora import apply_lora
@ -32,23 +32,49 @@ def load_engines(training=True):
engines = dict()
for name, model in models.items():
state = None
stats = None
lora = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
loads_state_dict = cfg.trainer.load_state_dict or inferencing
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
lora = cfg.lora
checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")
loads_state_dict = True
# load state early
if loads_state_dict:
state = torch.load(load_path, map_location=torch.device(cfg.device))
# check if config is defined in state, and re-initialize the model
if "config" in state:
print("Model config definition in weights, re-loading...")
config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
hyper_config = model.config
optimizer = None
lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
loads_state_dict = cfg.trainer.load_state_dict or inferencing
ddp = cfg.trainer.ddp
engine_class = LocalEngine if backend == "local" or inferencing else Engine
if inferencing:
model.config.training = False
# apply model replacers
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
@ -58,6 +84,9 @@ def load_engines(training=True):
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if inferencing:
model.config.training = False
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
@ -116,22 +145,8 @@ def load_engines(training=True):
optimizer = None
lr_scheduler = None
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")
loads_state_dict = True
stats = None
# load state dict if requested / required
if loads_state_dict:
state = torch.load(load_path, map_location=torch.device(cfg.device))
# state dict is not just the module, extract the extra trainer details
if "stats" in state:
stats = state["stats"]

View File

@ -320,11 +320,21 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}, callback=None):
def export(self, userdata={}, callback=None, dtype=None):
if dtype is None:
dtype = cfg.trainer.dtype
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / "fp32.pth"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
if not isinstance(config, dict):
config = config.__dict__
# safety
for k, v in module.items():
module[k] = v.to(dtype)
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
@ -339,7 +349,8 @@ class Engines(dict[str, Engine]):
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata
"userdata": userdata,
"config": config
}
if callback:
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

View File

@ -56,7 +56,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
return state_dict
def extract_lora( state_dict, config = None, save_path = None ):
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
dtype = cfg.inference.dtype
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
@ -71,7 +74,10 @@ def extract_lora( state_dict, config = None, save_path = None ):
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / "lora.pth"
torch.save( { "module": lora }, save_path )
torch.save( {
"module": lora,
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
}, save_path )
return state_dict
@ -81,6 +87,7 @@ def main():
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
args, unknown = parser.parse_known_args()
if args.module_only:
@ -95,7 +102,10 @@ def main():
if args.hf and args.lora:
raise Exception("Requesting more than one callback")
engines = load_engines()
if args.dtype != "auto":
cfg.trainer.weight_dtype = args.dtype
engines = load_engines(training=False)
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
if __name__ == "__main__":

View File

@ -581,6 +581,18 @@ class Base(nn.Module):
))
self.model = RetNetDecoder(RetNetConfig(**kwargs))
# do some funny stuff for LoRA training
"""
if self.gradient_checkpointing:
def make_inputs_require_grads(module, input, output):
for i, t in enumerate(input):
if not isinstance(t, torch.Tensor):
continue
t.requires_grad_(True)
self.model.register_forward_hook(make_inputs_require_grads)
"""
elif self.arch_type == "retnet-hf":
kwargs = dict(
vocab_size=n_resp_tokens,
@ -713,7 +725,7 @@ class Base(nn.Module):
x = inputs
m = mask.squeeze(-1).int()
aux_loss = None
# HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]:
kwargs = dict(

View File

@ -148,6 +148,7 @@ class ParameterizedLoRA(nn.Module):
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name: