maybe backported some weird fixes for LoRA loading from mrq/vall-e ?
This commit is contained in:
parent
90ecf3da7d
commit
f25e765682
|
@ -218,6 +218,7 @@ class LoRA:
|
|||
rank: int = 8 # rank for the LoRA
|
||||
alpha: int = 16 # rank for the LoRA
|
||||
training: bool = True #
|
||||
embeddings: bool = False
|
||||
parametrize: bool = False #
|
||||
module: str = "linear" # linear | conv1d
|
||||
|
||||
|
|
|
@ -10,9 +10,9 @@ elif cfg.trainer.backend == "local":
|
|||
|
||||
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
||||
|
||||
from ..models import get_models
|
||||
from ..models import get_models, get_model
|
||||
from ..utils import wrapper as ml
|
||||
from ..models.lora import apply_lora
|
||||
from ..models.lora import apply_lora, lora_load_state_dict
|
||||
|
||||
import torch
|
||||
import re
|
||||
|
@ -32,23 +32,53 @@ def load_engines(training=True):
|
|||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
state = None
|
||||
stats = None
|
||||
lora = None
|
||||
|
||||
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
|
||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
|
||||
|
||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
|
||||
# actually use the lora-specific checkpoint if available
|
||||
if cfg.lora is not None:
|
||||
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
|
||||
|
||||
# to handle the issue of training with deepspeed, but inferencing with local
|
||||
if checkpoint_path.exists() and backend == "local":
|
||||
tag = open(checkpoint_path).read()
|
||||
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
print("Checkpoint missing, but weights found.")
|
||||
loads_state_dict = True
|
||||
|
||||
# load state early
|
||||
if loads_state_dict:
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# check if config is defined in state, and re-initialize the model
|
||||
if "config" in state and False:
|
||||
print("Model config definition in weights, re-loading...")
|
||||
config_state = state["config"]
|
||||
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
|
||||
|
||||
hyper_config = model.config
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
inferencing = cfg.mode == "inferencing" or not model.config.training
|
||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||
loads_state_dict = cfg.trainer.load_state_dict or inferencing
|
||||
ddp = cfg.trainer.ddp
|
||||
|
||||
engine_class = LocalEngine if backend == "local" or inferencing else Engine
|
||||
|
||||
if inferencing:
|
||||
model.config.training = False
|
||||
engine_class = LocalEngine if backend == "local" else Engine
|
||||
|
||||
# apply model replacers
|
||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||
model.model = ml.replace_linear( model.model )
|
||||
|
||||
|
@ -60,6 +90,9 @@ def load_engines(training=True):
|
|||
#model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
|
||||
model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
|
||||
|
||||
if inferencing:
|
||||
model.config.training = False
|
||||
|
||||
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
|
||||
optimizer_class = None
|
||||
scheduler_class = None
|
||||
|
@ -118,28 +151,14 @@ def load_engines(training=True):
|
|||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
|
||||
# actually use the lora-specific checkpoint if available
|
||||
if cfg.lora is not None:
|
||||
checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
print("Checkpoint missing, but weights found.")
|
||||
loads_state_dict = True
|
||||
|
||||
stats = None
|
||||
if loads_state_dict and load_path.exists():
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# load state dict if requested / required
|
||||
if loads_state_dict:
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
stats = state["stats"]
|
||||
|
||||
# do not load stats if we're training a LoRA
|
||||
if "lora" not in state:
|
||||
if cfg.lora is not None or cfg.trainer.restart_step_count:
|
||||
stats = None
|
||||
|
||||
if "module" in state:
|
||||
|
@ -161,23 +180,23 @@ def load_engines(training=True):
|
|||
for k in erase:
|
||||
del state[k]
|
||||
|
||||
# resize text embedding
|
||||
if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||
|
||||
# resize text embedding
|
||||
if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
|
||||
state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
|
||||
# resize embeddings
|
||||
if "text_emb.weight" in state:
|
||||
state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
|
||||
if "rvq_l_emb.weight" in state:
|
||||
state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# load lora weights if exists
|
||||
if cfg.lora is not None:
|
||||
lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
|
||||
if lora_path.exists():
|
||||
state = torch.load(lora_path, map_location=torch.device(cfg.device))
|
||||
state = state['lora' if 'lora' in state else 'module']
|
||||
model.load_state_dict(state, strict=False)
|
||||
# load lora weights if exists
|
||||
if cfg.lora is not None:
|
||||
lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
|
||||
if lora_path.exists():
|
||||
print( "Loaded LoRA state dict:", lora_path )
|
||||
|
||||
state = torch.load(lora_path, map_location=torch.device(cfg.device))
|
||||
state = state['lora' if 'lora' in state else 'module']
|
||||
lora_load_state_dict( model, state )
|
||||
|
||||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
|
@ -202,42 +221,12 @@ def load_engines(training=True):
|
|||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
engines.load_checkpoint(training=not inferencing)
|
||||
|
||||
# freeze requested params
|
||||
for name, engine in engines.items():
|
||||
engine.freeze(freeze_all=False)
|
||||
|
||||
"""
|
||||
# copy embeddings if requested
|
||||
if cfg.model._embeddings is not None:
|
||||
embeddings_path = cfg.rel_path / cfg.model._embeddings
|
||||
|
||||
if embeddings_path.exists():
|
||||
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
|
||||
if "module" in embeddings:
|
||||
embeddings = embeddings["module"]
|
||||
|
||||
frozen_params = set()
|
||||
|
||||
for k in list(embeddings.keys()):
|
||||
if re.findall(r'_emb\.', k):
|
||||
frozen_params.add(k)
|
||||
else:
|
||||
del embeddings[k]
|
||||
|
||||
engine.module.load_state_dict(embeddings, strict=False)
|
||||
|
||||
# there's definitely a much better way but I can't be assed at the moment
|
||||
for name, param in engine.module.named_parameters():
|
||||
if name not in frozen_params:
|
||||
continue
|
||||
param.requires_grad_(False)
|
||||
engine._frozen_params.add(param)
|
||||
"""
|
||||
|
||||
|
||||
#do_gc()
|
||||
|
||||
return engines
|
|
@ -81,7 +81,7 @@ class Engine():
|
|||
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
return freeze_non_lora_weights( self.module )
|
||||
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
|
||||
|
||||
for name, param in self.module.named_parameters():
|
||||
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
|
||||
|
@ -164,15 +164,19 @@ class Engine():
|
|||
|
||||
if tag is None:
|
||||
tag_path = load_dir / "latest"
|
||||
|
||||
if not tag_path.exists():
|
||||
return
|
||||
|
||||
tag = open(tag_path).read()
|
||||
|
||||
load_path = load_dir / tag / "state.pth"
|
||||
|
||||
if not load_path.exists():
|
||||
return
|
||||
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
|
||||
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
||||
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
||||
|
@ -320,11 +324,21 @@ class Engines(dict[str, Engine]):
|
|||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def export(self, userdata={}, callback=None):
|
||||
def export(self, userdata={}, callback=None, dtype=None):
|
||||
if dtype is None:
|
||||
dtype = cfg.trainer.dtype
|
||||
|
||||
for name, engine in self.items():
|
||||
module = engine.module.state_dict()
|
||||
lora = None
|
||||
save_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
|
||||
if not isinstance(config, dict):
|
||||
config = config.__dict__
|
||||
|
||||
# safety
|
||||
for k, v in module.items():
|
||||
module[k] = v.to(dtype)
|
||||
|
||||
if cfg.lora is not None:
|
||||
lora, module = lora_get_state_dict( module, split = True )
|
||||
|
@ -339,8 +353,13 @@ class Engines(dict[str, Engine]):
|
|||
"global_samples": engine.global_samples,
|
||||
"tokens_processed": engine.tokens_processed,
|
||||
},
|
||||
"userdata": userdata
|
||||
"userdata": userdata,
|
||||
"config": config
|
||||
}
|
||||
|
||||
if lora is None:
|
||||
del state_dict['lora']
|
||||
|
||||
if callback:
|
||||
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
|
||||
|
||||
|
@ -378,18 +397,19 @@ class Engines(dict[str, Engine]):
|
|||
p.unlink()
|
||||
d.rmdir()
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
def load_checkpoint(self, tag=None, training=True):
|
||||
if not tag:
|
||||
tag = cfg.trainer.load_tag
|
||||
|
||||
for name, engine in self.items():
|
||||
load_dir = cfg.ckpt_dir / name
|
||||
|
||||
engine.load_checkpoint(
|
||||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=cfg.trainer.strict_loading,
|
||||
load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
|
||||
load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
|
||||
load_module_only=cfg.trainer.load_module_only,
|
||||
)
|
||||
if cfg.trainer.restart_step_count:
|
||||
|
|
|
@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
|
|||
from ..utils.distributed import init_distributed, distributed_initialized
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
from ..models.lora import freeze_non_lora_weights
|
||||
|
||||
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
|
||||
init_distributed(init_deepspeed_dist)
|
||||
|
||||
|
@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
|
|||
def freeze(self, freeze_all=True):
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
for name, param in self.module.named_parameters():
|
||||
should = 'lora_' in name
|
||||
param.requires_grad_(should)
|
||||
if not should:
|
||||
self._frozen_params.add(param)
|
||||
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
|
||||
for param in frozen_params:
|
||||
self._frozen_params.add( param )
|
||||
|
||||
return
|
||||
|
||||
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
|
||||
|
|
|
@ -8,7 +8,10 @@ from .engines import load_engines
|
|||
from .config import cfg
|
||||
from .models.lora import lora_get_state_dict
|
||||
|
||||
def extract_lora( state_dict, config = None, save_path = None ):
|
||||
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
||||
if dtype is None:
|
||||
dtype = cfg.inference.dtype
|
||||
|
||||
lora = state_dict["lora"] if "lora" in state_dict else None
|
||||
# should always be included, but just in case
|
||||
if lora is None and "module" in state_dict:
|
||||
|
@ -23,15 +26,18 @@ def extract_lora( state_dict, config = None, save_path = None ):
|
|||
# save lora specifically
|
||||
# should probably export other attributes, similar to what SD LoRAs do
|
||||
save_path = save_path.parent / "lora.pth"
|
||||
torch.save( { "module": lora }, save_path )
|
||||
torch.save( {
|
||||
"module": lora,
|
||||
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
|
||||
}, save_path )
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
|
||||
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
if args.module_only:
|
||||
|
@ -41,7 +47,10 @@ def main():
|
|||
if args.lora:
|
||||
callback = extract_lora
|
||||
|
||||
engines = load_engines()
|
||||
if args.dtype != "auto":
|
||||
cfg.trainer.weight_dtype = args.dtype
|
||||
|
||||
engines = load_engines(training=False)
|
||||
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -148,6 +147,7 @@ class ParameterizedLoRA(nn.Module):
|
|||
def passes_policy( policy, name ):
|
||||
if policy is None:
|
||||
return True
|
||||
|
||||
if "exclude" in policy:
|
||||
for term in policy["exclude"]:
|
||||
if term in name:
|
||||
|
@ -192,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
|
|||
else:
|
||||
setattr( model.get_submodule(name), k, replacement )
|
||||
|
||||
return model
|
||||
return enable_lora( model )
|
||||
|
||||
def enable_lora( model, mode = True ):
|
||||
for name, module in model.named_modules():
|
||||
|
@ -204,10 +204,18 @@ def enable_lora( model, mode = True ):
|
|||
def disable_lora( model ):
|
||||
return enable_lora( model, False )
|
||||
|
||||
def freeze_non_lora_weights( model ):
|
||||
def freeze_non_lora_weights( model, embeddings = False ):
|
||||
frozen_params = []
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
param.requires_grad_('lora_' in name)
|
||||
return model
|
||||
should = 'lora_' in name or (embeddings and "_emb" in name)
|
||||
|
||||
param.requires_grad_(should)
|
||||
|
||||
if not should:
|
||||
frozen_params.append( param )
|
||||
|
||||
return frozen_params
|
||||
|
||||
def lora_get_state_dict( state_dict, split = True ):
|
||||
lora = { name: param for name, param in state_dict.items() if "lora_" in name }
|
||||
|
|
Loading…
Reference in New Issue
Block a user