maybe backported some weird fixes for LoRA loading from mrq/vall-e ?

mrq 2024-07-22 20:48:06 -05:00
parent 90ecf3da7d
commit f25e765682
6 changed files with 120 additions and 92 deletions

View File

@@ -218,6 +218,7 @@ class LoRA:
rank: int = 8 # rank for the LoRA
alpha: int = 16 # alpha (scaling factor) for the LoRA
training: bool = True # whether this LoRA is being trained
embeddings: bool = False # also keep embedding ("_emb") weights unfrozen when training the LoRA
parametrize: bool = False # use parametrization (ParameterizedLoRA) instead of replacing the module
module: str = "linear" # linear | conv1d
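For reference, rank and alpha here follow the usual LoRA convention from microsoft/LoRA's loralib (which the lora.py further down notes it is adapted from): the low-rank update is scaled by alpha / rank, so the defaults above give a scale of 16 / 8 = 2.0. A minimal, self-contained sketch of that math (illustrative only, not this repo's ParameterizedLoRA):

import torch
from torch import nn

class LoRALinearSketch(nn.Module):
    # toy example of the rank/alpha relationship; parameter names mimic the "lora_" convention
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)          # base weights stay frozen
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scale = alpha / rank            # 16 / 8 = 2.0 with the defaults above

    def forward(self, x):
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale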

View File

@@ -10,9 +10,9 @@ elif cfg.trainer.backend == "local":
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..models.lora import apply_lora
from ..models.lora import apply_lora, lora_load_state_dict
import torch
import re
@@ -32,23 +32,53 @@ def load_engines(training=True):
engines = dict()
for name, model in models.items():
state = None
stats = None
lora = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
# handle the case of training with deepspeed but inferencing with the local backend
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")
loads_state_dict = True
# load state early
if loads_state_dict:
state = torch.load(load_path, map_location=torch.device(cfg.device))
# check if a config is defined in the state dict and re-initialize the model from it (currently disabled by the `and False` below)
if "config" in state and False:
print("Model config definition in weights, re-loading...")
config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
hyper_config = model.config
optimizer = None
lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model.config.training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
loads_state_dict = cfg.trainer.load_state_dict or inferencing
ddp = cfg.trainer.ddp
engine_class = LocalEngine if backend == "local" or inferencing else Engine
if inferencing:
model.config.training = False
engine_class = LocalEngine if backend == "local" else Engine
# apply model replacers
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
@@ -60,6 +90,9 @@ def load_engines(training=True):
#model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if inferencing:
model.config.training = False
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
@@ -118,28 +151,14 @@ def load_engines(training=True):
optimizer = None
lr_scheduler = None
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")
loads_state_dict = True
stats = None
if loads_state_dict and load_path.exists():
state = torch.load(load_path, map_location=torch.device(cfg.device))
# load state dict if requested / required
if loads_state_dict:
# the state dict is not just the module weights; extract the extra trainer details
if "stats" in state:
stats = state["stats"]
# do not load stats if we're training a LoRA or restarting the step count
if "lora" not in state:
if cfg.lora is not None or cfg.trainer.restart_step_count:
stats = None
if "module" in state:
@@ -161,23 +180,23 @@ def load_engines(training=True):
for k in erase:
del state[k]
# resize text embedding
if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
# resize RVQ level embedding
if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
# resize embeddings
if "text_emb.weight" in state:
state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
if "rvq_l_emb.weight" in state:
state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load LoRA weights if they exist
if cfg.lora is not None:
lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
if lora_path.exists():
state = torch.load(lora_path, map_location=torch.device(cfg.device))
state = state['lora' if 'lora' in state else 'module']
model.load_state_dict(state, strict=False)
# load LoRA weights if they exist
if cfg.lora is not None:
lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
if lora_path.exists():
print( "Loaded LoRA state dict:", lora_path )
state = torch.load(lora_path, map_location=torch.device(cfg.device))
state = state['lora' if 'lora' in state else 'module']
lora_load_state_dict( model, state )
# wrap if DDP is requested
if ddp:
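The embedding resize above now goes through ml.resize_weight instead of a plain slice. That helper's implementation is not part of this diff, so the following is only a guess at the behavior it would need to provide (truncate when shrinking, zero-pad when growing):

import torch

def resize_weight_sketch(weight: torch.Tensor, target: int) -> torch.Tensor:
    # hypothetical stand-in for ml.resize_weight(); the real helper is not shown here
    if target <= weight.shape[0]:
        return weight[:target]
    padded = torch.zeros(target, *weight.shape[1:], dtype=weight.dtype)
    padded[:weight.shape[0]] = weight
    return padded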
@@ -202,42 +221,12 @@ def load_engines(training=True):
engines = Engines(engines)
engines.setup()
# this might bite me in the ass, since technically this doesn't handle the case where one engine loads fine but another doesn't
if not cfg.trainer.load_state_dict:
engines.load_checkpoint()
engines.load_checkpoint(training=not inferencing)
# freeze requested params
for name, engine in engines.items():
engine.freeze(freeze_all=False)
"""
# copy embeddings if requested
if cfg.model._embeddings is not None:
embeddings_path = cfg.rel_path / cfg.model._embeddings
if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
if "module" in embeddings:
embeddings = embeddings["module"]
frozen_params = set()
for k in list(embeddings.keys()):
if re.findall(r'_emb\.', k):
frozen_params.add(k)
else:
del embeddings[k]
engine.module.load_state_dict(embeddings, strict=False)
# there's definitely a much better way but I can't be assed at the moment
for name, param in engine.module.named_parameters():
if name not in frozen_params:
continue
param.requires_grad_(False)
engine._frozen_params.add(param)
"""
#do_gc()
return engines
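The checkpoint resolution that moved to the top of load_engines() boils down to: prefer the (LoRA-specific) DeepSpeed checkpoint, translate its latest tag into the exported state.pth when inferencing on the local backend, and fall back to the raw fp32.pth weights when no checkpoint exists. A condensed, hypothetical restatement of that order (the helper name and signature are mine, not the repo's):

from pathlib import Path

def resolve_checkpoint(ckpt_dir: Path, name: str, lora_name=None, backend="local"):
    # condensed sketch of the logic above, not part of the repo
    checkpoint_path = ckpt_dir / (lora_name or name) / "latest"
    load_path = ckpt_dir / name / "fp32.pth"
    if checkpoint_path.exists() and backend == "local":
        # DeepSpeed's "latest" file only stores the tag of the newest checkpoint folder
        tag = open(checkpoint_path).read()
        checkpoint_path = ckpt_dir / (lora_name or name) / tag / "state.pth"
    if checkpoint_path.exists():
        return checkpoint_path, False   # consumed by the engine's checkpoint loader
    if load_path.exists():
        return load_path, True          # consumed by torch.load() + load_state_dict()
    return None, False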

View File

@@ -81,7 +81,7 @@ class Engine():
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
return freeze_non_lora_weights( self.module )
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
@@ -164,15 +164,19 @@ class Engine():
if tag is None:
tag_path = load_dir / "latest"
if not tag_path.exists():
return
tag = open(tag_path).read()
load_path = load_dir / tag / "state.pth"
if not load_path.exists():
return
state = torch.load(load_path, map_location=torch.device(cfg.device))
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -320,11 +324,21 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}, callback=None):
def export(self, userdata={}, callback=None, dtype=None):
if dtype is None:
dtype = cfg.trainer.dtype
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / "fp32.pth"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
if not isinstance(config, dict):
config = config.__dict__
# safety: make sure every exported tensor is in the target dtype
for k, v in module.items():
module[k] = v.to(dtype)
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
@@ -339,8 +353,13 @@ class Engines(dict[str, Engine]):
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata
"userdata": userdata,
"config": config
}
if lora is None:
del state_dict['lora']
if callback:
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
@@ -378,18 +397,19 @@ class Engines(dict[str, Engine]):
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None):
def load_checkpoint(self, tag=None, training=True):
if not tag:
tag = cfg.trainer.load_tag
for name, engine in self.items():
load_dir = cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_module_only=cfg.trainer.load_module_only,
)
if cfg.trainer.restart_step_count:
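With the dtype cast and the new config entry, the exported checkpoint ends up roughly shaped like the dict below. The key names are taken from the hunks above; the placeholder values are made up:

import torch

module = {"weight": torch.zeros(2, 2, dtype=torch.float16)}   # every tensor cast to the export dtype
state_dict = {
    "module": module,
    "lora": None,    # split out via lora_get_state_dict() when a LoRA is configured, deleted otherwise
    "stats": {"global_step": 0, "micro_step": 0, "global_samples": 0, "tokens_processed": 0},
    "userdata": {"symmap": {}},
    "config": {"name": "example"},   # engine/module config flattened to a plain dict
}
if state_dict["lora"] is None:
    del state_dict["lora"]
torch.save(state_dict, "fp32.pth")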

View File

@@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..models.lora import freeze_non_lora_weights
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
init_distributed(init_deepspeed_dist)
@@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
for name, param in self.module.named_parameters():
should = 'lora_' in name
param.requires_grad_(should)
if not should:
self._frozen_params.add(param)
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for param in frozen_params:
self._frozen_params.add( param )
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):

View File

@@ -8,7 +8,10 @@ from .engines import load_engines
from .config import cfg
from .models.lora import lora_get_state_dict
def extract_lora( state_dict, config = None, save_path = None ):
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
dtype = cfg.inference.dtype
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
@@ -23,15 +26,18 @@ def extract_lora( state_dict, config = None, save_path = None ):
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / "lora.pth"
torch.save( { "module": lora }, save_path )
torch.save( {
"module": lora,
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
}, save_path )
return state_dict
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
args, unknown = parser.parse_known_args()
if args.module_only:
@@ -41,7 +47,10 @@ def main():
if args.lora:
callback = extract_lora
engines = load_engines()
if args.dtype != "auto":
cfg.trainer.weight_dtype = args.dtype
engines = load_engines(training=False)
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
if __name__ == "__main__":
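Assuming the exporter is importable the usual way, the reworked extract_lora can also be driven by hand on an already-exported checkpoint; the import path and file layout below are assumptions, only the function signature comes from the diff:

import torch
from pathlib import Path

from vall_e.export import extract_lora   # assumed module path

save_path = Path("ckpt/model/fp32.pth")  # assumed checkpoint layout
state_dict = torch.load(save_path, map_location="cpu")
extract_lora(state_dict, save_path=save_path, dtype=torch.float16)
# writes ckpt/model/lora.pth containing {"module": ..., "config": ...}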

View File

@@ -1,5 +1,4 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial
import torch
import torch.nn.functional as F
@@ -148,6 +147,7 @@ class ParameterizedLoRA(nn.Module):
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
@@ -192,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
else:
setattr( model.get_submodule(name), k, replacement )
return model
return enable_lora( model )
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
@@ -204,10 +204,18 @@ def enable_lora( model, mode = True ):
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model ):
def freeze_non_lora_weights( model, embeddings = False ):
frozen_params = []
for name, param in model.named_parameters():
param.requires_grad_('lora_' in name)
return model
should = 'lora_' in name or (embeddings and "_emb" in name)
param.requires_grad_(should)
if not should:
frozen_params.append( param )
return frozen_params
def lora_get_state_dict( state_dict, split = True ):
lora = { name: param for name, param in state_dict.items() if "lora_" in name }
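The reworked freeze_non_lora_weights() keys entirely off parameter names: anything containing "lora_" stays trainable, and with embeddings=True anything containing "_emb" does too. A tiny runnable illustration of that naming convention on a toy module (not the real model):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)              # ordinary weight, gets frozen
        self.text_emb = nn.Embedding(8, 4)       # "_emb" weight, trainable when embeddings=True
        self.lora_A = nn.Parameter(torch.zeros(2, 4))
        self.lora_B = nn.Parameter(torch.zeros(4, 2))

toy = Toy()
embeddings = True   # mirrors the new cfg.lora.embeddings flag
frozen = []
for name, param in toy.named_parameters():
    should = "lora_" in name or (embeddings and "_emb" in name)
    param.requires_grad_(should)
    if not should:
        frozen.append(name)
print(frozen)   # ['proj.weight', 'proj.bias']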