mechanism to store the model config inside the weights and load it back, plus some other things to allow LoRA training on the RetNet (gradient checkpointing will gripe about inputs not having requires_grad, and nothing seems to remedy it)
parent 3acc54df22
commit fe0f235335
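In short: Engines.export() now writes the model's config dict next to the weights, and load_engines() rebuilds the model from that dict when it finds one in the checkpoint. A minimal sketch of the round trip (HyperConfig, export_weights, load_weights and build_model are stand-in names for illustration, not this repo's API):

import torch
from dataclasses import dataclass, asdict

@dataclass
class HyperConfig:                      # stand-in for the repo's model config class
    dim: int = 1024
    layers: int = 12
    arch_type: str = "retnet"

def export_weights(model, config, path):
    # store the config next to the weights so the checkpoint is self-describing
    torch.save({"module": model.state_dict(), "config": asdict(config)}, path)

def load_weights(path, build_model):
    state = torch.load(path, map_location="cpu")
    # if a config rode along with the weights, rebuild the model from it first
    config = HyperConfig(**state["config"]) if "config" in state else HyperConfig()
    model = build_model(config)
    model.load_state_dict(state["module"])
    return model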
@@ -338,6 +338,9 @@ class Model:
        if self.arch_type == "llama":
            include = ["self_attn", "mlp"] # target only the attention + mlp
            exclude = ["self_attn.k_proj"] # common literature says to ignore it
        if self.arch_type == "retnet":
            include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
            exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet

        return dict(include=include, exclude=exclude)
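apply_lora itself isn't part of this diff, but the intent of the policy dict is to only wrap Linear layers whose dotted module name hits an include term and none of the exclude terms. A rough standalone sketch of that filtering (matches_policy and lora_target_names are illustrative names, and the exact include/exclude precedence is an assumption):

import torch.nn as nn

def matches_policy(policy, name):
    # exclude wins over include; a missing policy matches everything
    if policy is None:
        return True
    if any(term in name for term in policy.get("exclude", [])):
        return False
    include = policy.get("include", [])
    return not include or any(term in name for term in include)

def lora_target_names(model: nn.Module, policy):
    # dotted names of the Linear layers a LoRA wrapper would attach to
    return [name for name, module in model.named_modules()
            if isinstance(module, nn.Linear) and matches_policy(policy, name)]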
@@ -585,6 +588,7 @@ class Trainer:
    load_disabled_engines: bool = False

    weight_dtype: str = "float16"

    amp: bool = False
    ddp: bool = False
@@ -10,7 +10,7 @@ elif cfg.trainer.backend == "local":

from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine

from ..models import get_models
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..models.lora import apply_lora
@@ -32,23 +32,49 @@ def load_engines(training=True):
    engines = dict()

    for name, model in models.items():
        state = None
        stats = None
        lora = None

        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
        loads_state_dict = cfg.trainer.load_state_dict or inferencing

        checkpoint_path = cfg.ckpt_dir / name / "latest"
        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        load_path = cfg.ckpt_dir / name / "fp32.pth"

        # actually use the lora-specific checkpoint if available
        if cfg.lora is not None:
            lora = cfg.lora
            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"

        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
            print("Checkpoint missing, but weights found.")
            loads_state_dict = True

        # load state early
        if loads_state_dict:
            state = torch.load(load_path, map_location=torch.device(cfg.device))

            # check if config is defined in state, and re-initialize the model
            if "config" in state:
                print("Model config definition in weights, re-loading...")
                config_state = state["config"]
                model = get_model( config=cfg.model.__class__( *config_state ), training=training )

        hyper_config = model.config

        optimizer = None
        lr_scheduler = None

        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
        dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
        amp = cfg.inference.amp if inferencing else cfg.trainer.amp
        loads_state_dict = cfg.trainer.load_state_dict or inferencing
        ddp = cfg.trainer.ddp

        engine_class = LocalEngine if backend == "local" or inferencing else Engine

        if inferencing:
            model.config.training = False

        # apply model replacers
        if cfg.optimizations.replace and cfg.optimizations.linear:
            model.model = ml.replace_linear( model.model )
@@ -58,6 +84,9 @@ def load_engines(training=True):
        for lora in cfg.loras:
            model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )

        if inferencing:
            model.config.training = False

        if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
            optimizer_class = None
            scheduler_class = None
@@ -116,22 +145,8 @@ def load_engines(training=True):
        optimizer = None
        lr_scheduler = None

        checkpoint_path = cfg.ckpt_dir / name / "latest"
        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        load_path = cfg.ckpt_dir / name / "fp32.pth"

        # actually use the lora-specific checkpoint if available
        if cfg.lora is not None:
            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"

        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
            print("Checkpoint missing, but weights found.")
            loads_state_dict = True

        stats = None
        # load state dict if requested / required
        if loads_state_dict:
            state = torch.load(load_path, map_location=torch.device(cfg.device))

            # state dict is not just the module, extract the extra trainer details
            if "stats" in state:
                stats = state["stats"]
@@ -320,11 +320,21 @@ class Engines(dict[str, Engine]):
        for engine in self.values():
            engine.dispatch_attribute(*args, **kwargs)

    def export(self, userdata={}, callback=None):
    def export(self, userdata={}, callback=None, dtype=None):
        if dtype is None:
            dtype = cfg.trainer.dtype

        for name, engine in self.items():
            module = engine.module.state_dict()
            lora = None
            save_path = cfg.ckpt_dir / name / "fp32.pth"
            config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
            if not isinstance(config, dict):
                config = config.__dict__

            # safety
            for k, v in module.items():
                module[k] = v.to(dtype)

            if cfg.lora is not None:
                lora, module = lora_get_state_dict( module, split = True )
@@ -339,7 +349,8 @@ class Engines(dict[str, Engine]):
                    "global_samples": engine.global_samples,
                    "tokens_processed": engine.tokens_processed,
                },
                "userdata": userdata
                "userdata": userdata,
                "config": config
            }
            if callback:
                state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
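The new dtype plumbing just casts every tensor in the module's state dict before it gets written out, so a run trained in float32 can still be exported at half precision. A standalone sketch of that cast (a conservative variant of the loop above that leaves non-floating buffers alone):

import torch

def cast_state_dict(module_sd, dtype=torch.float16):
    # cast floating-point tensors to the target dtype, keep integer buffers as-is
    return {k: (v.to(dtype) if torch.is_tensor(v) and v.is_floating_point() else v)
            for k, v in module_sd.items()}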
@@ -56,7 +56,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):

    return state_dict

def extract_lora( state_dict, config = None, save_path = None ):
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
    if dtype is None:
        dtype = cfg.inference.dtype

    lora = state_dict["lora"] if "lora" in state_dict else None
    # should always be included, but just in case
    if lora is None and "module" in state_dict:
@@ -71,7 +74,10 @@ def extract_lora( state_dict, config = None, save_path = None ):
        # save lora specifically
        # should probably export other attributes, similar to what SD LoRAs do
        save_path = save_path.parent / "lora.pth"
        torch.save( { "module": lora }, save_path )
        torch.save( {
            "module": lora,
            "config": cfg.lora.__dict__ if cfg.lora is not None else None,
        }, save_path )

    return state_dict
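Reading the exported lora.pth back is then just the mirror of the save above (a minimal sketch assuming that layout):

import torch

state = torch.load("lora.pth", map_location="cpu")
lora_weights = state["module"]   # just the LoRA tensors, split out of the full state dict
lora_settings = state["config"]  # plain dict of the LoRA settings (rank, alpha, ...), or None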
@@ -81,6 +87,7 @@ def main():
    parser.add_argument("--module-only", action='store_true')
    parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
    parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
    parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
    args, unknown = parser.parse_known_args()

    if args.module_only:
@@ -95,7 +102,10 @@ def main():
    if args.hf and args.lora:
        raise Exception("Requesting more than one callback")

    engines = load_engines()
    if args.dtype != "auto":
        cfg.trainer.weight_dtype = args.dtype

    engines = load_engines(training=False)
    engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)

if __name__ == "__main__":
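--dtype is taken as a plain string and pushed into cfg.trainer.weight_dtype; presumably the config layer resolves such strings to a torch dtype along these lines (a guess at that helper, not the repo's actual code):

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # map a config string like "float16" / "bfloat16" / "float32" onto the torch dtype
    aliases = {"fp16": "float16", "half": "float16", "bf16": "bfloat16", "fp32": "float32"}
    return getattr(torch, aliases.get(name, name))

assert resolve_dtype("bfloat16") is torch.bfloat16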
@@ -581,6 +581,18 @@ class Base(nn.Module):
            ))

            self.model = RetNetDecoder(RetNetConfig(**kwargs))

            # do some funny stuff for LoRA training
            """
            if self.gradient_checkpointing:
                def make_inputs_require_grads(module, input, output):
                    for i, t in enumerate(input):
                        if not isinstance(t, torch.Tensor):
                            continue
                        t.requires_grad_(True)

                self.model.register_forward_hook(make_inputs_require_grads)
            """
        elif self.arch_type == "retnet-hf":
            kwargs = dict(
                vocab_size=n_resp_tokens,
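The commented-out hook above registers on the whole decoder, and a forward hook only fires after that forward has already run, which is presumably why it never helps. What transformers does for the same problem (enable_input_require_grads()) is hook the token embedding and flag its output, so the activations entering the checkpointed layers require grad. A rough sketch of that shape, where the embed_tokens attribute name is a guess:

import torch.nn as nn

def enable_input_grads(model: nn.Module, embedding_attr: str = "embed_tokens"):
    # flag the embedding output so activations entering checkpointed layers require grad,
    # mirroring transformers' enable_input_require_grads()
    def make_inputs_require_grads(module, input, output):
        output.requires_grad_(True)
    embed = getattr(model, embedding_attr)  # "embed_tokens" is an assumed attribute name
    return embed.register_forward_hook(make_inputs_require_grads)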
@@ -713,7 +725,7 @@ class Base(nn.Module):
        x = inputs
        m = mask.squeeze(-1).int()
        aux_loss = None

        # HF transformer derived model
        if self.arch_type in ["llama", "mistral", "mixtral"]:
            kwargs = dict(
@@ -148,6 +148,7 @@ class ParameterizedLoRA(nn.Module):
def passes_policy( policy, name ):
    if policy is None:
        return True

    if "exclude" in policy:
        for term in policy["exclude"]:
            if term in name: