maybe backported some weird fixes for LoRA loading from mrq/vall-e ?
parent 90ecf3da7d
commit f25e765682
@@ -218,6 +218,7 @@ class LoRA:
     rank: int = 8 # rank for the LoRA
     alpha: int = 16 # alpha for the LoRA
     training: bool = True #
+    embeddings: bool = False
     parametrize: bool = False #
     module: str = "linear" # linear | conv1d
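The new `embeddings` flag feeds the freeze step further down (`freeze_non_lora_weights( ..., embeddings=cfg.lora.embeddings )`). A minimal sketch of the intended flow, with the LoRA dataclass trimmed to just the fields visible in this hunk:

    from dataclasses import dataclass

    @dataclass
    class LoRA:
        rank: int = 8
        alpha: int = 16
        training: bool = True
        embeddings: bool = False   # new: keep *_emb weights trainable alongside the LoRA weights
        parametrize: bool = False
        module: str = "linear"     # linear | conv1d

    lora = LoRA(embeddings=True)
    # downstream, when freezing for LoRA training (see the engine hunks below):
    # freeze_non_lora_weights(model, embeddings=lora.embeddings)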
@@ -10,9 +10,9 @@ elif cfg.trainer.backend == "local":
 from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
-from ..models import get_models
+from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -32,23 +32,53 @@ def load_engines(training=True):
     engines = dict()
 
     for name, model in models.items():
+        state = None
+        stats = None
+        lora = None
+
+        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+        loads_state_dict = cfg.trainer.load_state_dict # or inferencing
+
+        checkpoint_path = cfg.ckpt_dir / name / "latest"
+        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+        load_path = cfg.ckpt_dir / name / "fp32.pth"
+
+        # actually use the lora-specific checkpoint if available
+        if cfg.lora is not None:
+            checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+        # to handle the issue of training with deepspeed, but inferencing with local
+        if checkpoint_path.exists() and backend == "local":
+            tag = open(checkpoint_path).read()
+            checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
+
+        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
+            print("Checkpoint missing, but weights found.")
+            loads_state_dict = True
+
+        # load state early
+        if loads_state_dict:
+            state = torch.load(load_path, map_location=torch.device(cfg.device))
+
+            # check if config is defined in state, and re-initialize the model
+            if "config" in state and False:
+                print("Model config definition in weights, re-loading...")
+                config_state = state["config"]
+                model = get_model( config=cfg.model.__class__( *config_state ), training=training )
+
         hyper_config = model.config
 
         optimizer = None
         lr_scheduler = None
 
-        inferencing = cfg.mode == "inferencing" or not model.config.training
-        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
         dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
         amp = cfg.inference.amp if inferencing else cfg.trainer.amp
-        loads_state_dict = cfg.trainer.load_state_dict or inferencing
         ddp = cfg.trainer.ddp
 
-        engine_class = LocalEngine if backend == "local" or inferencing else Engine
+        engine_class = LocalEngine if backend == "local" else Engine
 
-        if inferencing:
-            model.config.training = False
-
+        # apply model replacers
         if cfg.optimizations.replace and cfg.optimizations.linear:
             model.model = ml.replace_linear( model.model )
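The path juggling added above resolves roughly as follows; this is a standalone sketch, and `ckpt`, `model`, and `my-lora` are made-up placeholders for `cfg.ckpt_dir`, the model name, and `cfg.lora.full_name`:

    from pathlib import Path

    ckpt_dir = Path("ckpt")                  # stand-in for cfg.ckpt_dir
    name = "model"                           # hypothetical model name
    lora_name = "my-lora"                    # stand-in for cfg.lora.full_name
    loads_state_dict = False                 # stand-in for cfg.trainer.load_state_dict

    checkpoint_path = ckpt_dir / lora_name / "latest"   # LoRA-specific DeepSpeed tag file
    load_path = ckpt_dir / name / "fp32.pth"            # plain fp32 weights fallback

    # trained with deepspeed but inferencing locally: "latest" only holds the tag name,
    # and the actual weights live under <tag>/state.pth
    if checkpoint_path.exists():
        tag = checkpoint_path.read_text().strip()
        checkpoint_path = ckpt_dir / lora_name / tag / "state.pth"

    # no checkpoint at all, but raw weights exist: fall back to state-dict loading
    if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
        loads_state_dict = True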
@@ -60,6 +90,9 @@ def load_engines(training=True):
             #model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
             model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
 
+        if inferencing:
+            model.config.training = False
+
         if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
             optimizer_class = None
             scheduler_class = None
@@ -118,28 +151,14 @@ def load_engines(training=True):
         optimizer = None
         lr_scheduler = None
 
-        checkpoint_path = cfg.ckpt_dir / name / "latest"
-        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-        load_path = cfg.ckpt_dir / name / "fp32.pth"
-
-        # actually use the lora-specific checkpoint if available
-        if cfg.lora is not None:
-            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
-
-        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-            print("Checkpoint missing, but weights found.")
-            loads_state_dict = True
-
-        stats = None
-        if loads_state_dict and load_path.exists():
-            state = torch.load(load_path, map_location=torch.device(cfg.device))
-
+        # load state dict if requested / required
+        if loads_state_dict:
             # state dict is not just the module, extract the extra trainer details
             if "stats" in state:
                 stats = state["stats"]
 
             # do not load stats if we're training a LoRA
-            if "lora" not in state:
+            if cfg.lora is not None or cfg.trainer.restart_step_count:
                 stats = None
 
             if "module" in state:
@@ -161,23 +180,23 @@ def load_engines(training=True):
                 for k in erase:
                     del state[k]
 
-            # resize text embedding
-            if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
-                state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
-            # resize text embedding
-            if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
-                state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
+            # resize embeddings
+            if "text_emb.weight" in state:
+                state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
+            if "rvq_l_emb.weight" in state:
+                state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
 
             model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
         # load lora weights if exists
         if cfg.lora is not None:
-            lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+            lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
             if lora_path.exists():
-                state = torch.load(lora_path, map_location=torch.device(cfg.device))
-                state = state['lora' if 'lora' in state else 'module']
-                model.load_state_dict(state, strict=False)
+                print( "Loaded LoRA state dict:", lora_path )
+                state = torch.load(lora_path, map_location=torch.device(cfg.device))
+                state = state['lora' if 'lora' in state else 'module']
+                lora_load_state_dict( model, state )
 
         # wrap if DDP is requested
         if ddp:
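`ml.resize_weight` replaces the old hard truncation so a checkpoint still loads when an embedding table grew or shrank. Its implementation is not part of this diff; a plausible sketch of such a helper, purely as an illustration:

    import torch

    def resize_weight(weight: torch.Tensor, rows: int) -> torch.Tensor:
        # checkpoint has at least as many rows as the model expects: truncate (the old behaviour)
        if weight.shape[0] >= rows:
            return weight[:rows]
        # checkpoint has fewer rows: pad with zero-initialized rows, keeping the existing ones
        padding = torch.zeros(rows - weight.shape[0], *weight.shape[1:], dtype=weight.dtype)
        return torch.cat([weight, padding], dim=0)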
@@ -202,42 +221,12 @@ def load_engines(training=True):
     engines = Engines(engines)
     engines.setup()
 
+    # this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
     if not cfg.trainer.load_state_dict:
-        engines.load_checkpoint()
+        engines.load_checkpoint(training=not inferencing)
 
     # freeze requested params
     for name, engine in engines.items():
         engine.freeze(freeze_all=False)
 
-        """
-        # copy embeddings if requested
-        if cfg.model._embeddings is not None:
-            embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-            if embeddings_path.exists():
-                embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-                if "module" in embeddings:
-                    embeddings = embeddings["module"]
-
-                frozen_params = set()
-
-                for k in list(embeddings.keys()):
-                    if re.findall(r'_emb\.', k):
-                        frozen_params.add(k)
-                    else:
-                        del embeddings[k]
-
-                engine.module.load_state_dict(embeddings, strict=False)
-
-                # there's definitely a much better way but I can't be assed at the moment
-                for name, param in engine.module.named_parameters():
-                    if name not in frozen_params:
-                        continue
-                    param.requires_grad_(False)
-                    engine._frozen_params.add(param)
-        """
-
-    #do_gc()
-
     return engines
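`lora_load_state_dict` (imported at the top of this file) replaces the bare `model.load_state_dict(state, strict=False)` call. Its body is not shown in this commit; assuming it only applies the split-out `lora_` tensors non-strictly, it could be as small as:

    def lora_load_state_dict(model, state_dict):
        # assumption: keep only LoRA tensors and load them without requiring full key coverage
        state_dict = {k: v for k, v in state_dict.items() if "lora_" in k}
        model.load_state_dict(state_dict, strict=False)
        return model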
@@ -81,7 +81,7 @@ class Engine():
 
         # freeze non-LoRA params if requested
         if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-            return freeze_non_lora_weights( self.module )
+            return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
 
         for name, param in self.module.named_parameters():
             if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
@@ -164,15 +164,19 @@ class Engine():
 
         if tag is None:
             tag_path = load_dir / "latest"
 
             if not tag_path.exists():
                 return
 
             tag = open(tag_path).read()
 
         load_path = load_dir / tag / "state.pth"
 
         if not load_path.exists():
             return
 
         state = torch.load(load_path, map_location=torch.device(cfg.device))
 
         self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
         self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
         self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -320,11 +324,21 @@ class Engines(dict[str, Engine]):
         for engine in self.values():
             engine.dispatch_attribute(*args, **kwargs)
 
-    def export(self, userdata={}, callback=None):
+    def export(self, userdata={}, callback=None, dtype=None):
+        if dtype is None:
+            dtype = cfg.trainer.dtype
+
         for name, engine in self.items():
             module = engine.module.state_dict()
             lora = None
             save_path = cfg.ckpt_dir / name / "fp32.pth"
+            config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+            if not isinstance(config, dict):
+                config = config.__dict__
+
+            # safety
+            for k, v in module.items():
+                module[k] = v.to(dtype)
+
             if cfg.lora is not None:
                 lora, module = lora_get_state_dict( module, split = True )
@@ -339,8 +353,13 @@ class Engines(dict[str, Engine]):
                     "global_samples": engine.global_samples,
                     "tokens_processed": engine.tokens_processed,
                 },
-                "userdata": userdata
+                "userdata": userdata,
+                "config": config
             }
 
+            if lora is None:
+                del state_dict['lora']
+
             if callback:
                 state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
@@ -378,18 +397,19 @@ class Engines(dict[str, Engine]):
                 p.unlink()
             d.rmdir()
 
-    def load_checkpoint(self, tag=None):
+    def load_checkpoint(self, tag=None, training=True):
         if not tag:
             tag = cfg.trainer.load_tag
 
         for name, engine in self.items():
             load_dir = cfg.ckpt_dir / name
 
             engine.load_checkpoint(
                 tag=tag,
                 load_dir=load_dir,
                 load_module_strict=cfg.trainer.strict_loading,
-                load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-                load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+                load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+                load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
                 load_module_only=cfg.trainer.load_module_only,
             )
             if cfg.trainer.restart_step_count:
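Taken together, the two new knobs let an inference/export path skip optimizer state and pick a target precision. A rough usage sketch, assuming an `engines` instance from `load_engines()`; the `torch.float16` choice is only an example:

    import torch

    # skips optimizer and LR-scheduler states, which load_engines() now requests when inferencing
    engines.load_checkpoint(training=False)

    # casts every tensor in the exported module state dict; defaults to cfg.trainer.dtype when omitted
    engines.export(userdata={}, dtype=torch.float16)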
@@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
 from ..utils.distributed import init_distributed, distributed_initialized
 from ..utils import wrapper as ml
 
+from ..models.lora import freeze_non_lora_weights
+
 if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
     init_distributed(init_deepspeed_dist)
@@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
     def freeze(self, freeze_all=True):
         # freeze non-LoRA params if requested
         if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-            for name, param in self.module.named_parameters():
-                should = 'lora_' in name
-                param.requires_grad_(should)
-                if not should:
-                    self._frozen_params.add(param)
+            frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
+            for param in frozen_params:
+                self._frozen_params.add( param )
             return
 
         if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
@@ -8,7 +8,10 @@ from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
 
-def extract_lora( state_dict, config = None, save_path = None ):
+def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
+    if dtype is None:
+        dtype = cfg.inference.dtype
+
     lora = state_dict["lora"] if "lora" in state_dict else None
     # should always be included, but just in case
     if lora is None and "module" in state_dict:
@@ -23,15 +26,18 @@ def extract_lora( state_dict, config = None, save_path = None ):
     # save lora specifically
     # should probably export other attributes, similar to what SD LoRAs do
     save_path = save_path.parent / "lora.pth"
-    torch.save( { "module": lora }, save_path )
+    torch.save( {
+        "module": lora,
+        "config": cfg.lora.__dict__ if cfg.lora is not None else None,
+    }, save_path )
 
     return state_dict
 
 def main():
     parser = argparse.ArgumentParser("Save trained model to path.")
     parser.add_argument("--module-only", action='store_true')
     parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+    parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
     args, unknown = parser.parse_known_args()
 
     if args.module_only:
@@ -41,7 +47,10 @@ def main():
     if args.lora:
         callback = extract_lora
 
-    engines = load_engines()
+    if args.dtype != "auto":
+        cfg.trainer.weight_dtype = args.dtype
+
+    engines = load_engines(training=False)
     engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
 if __name__ == "__main__":
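With the new flag, exporting a LoRA at reduced precision goes roughly like this (the CLI entry point is a placeholder, since the package name is not part of this diff):

    # python3 -m <package>.export --lora --dtype float16
    # which main() now translates into:
    cfg.trainer.weight_dtype = "float16"     # from --dtype (skipped when "auto")
    engines = load_engines(training=False)   # inference-style load
    engines.export(userdata={"symmap": get_phone_symmap()}, callback=extract_lora)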
@@ -1,5 +1,4 @@
 # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
-
 from functools import partial
 import torch
 import torch.nn.functional as F
@@ -148,6 +147,7 @@ class ParameterizedLoRA(nn.Module):
 def passes_policy( policy, name ):
     if policy is None:
         return True
 
     if "exclude" in policy:
         for term in policy["exclude"]:
             if term in name:
@@ -192,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
         else:
             setattr( model.get_submodule(name), k, replacement )
 
-    return model
+    return enable_lora( model )
 
 def enable_lora( model, mode = True ):
     for name, module in model.named_modules():
@@ -204,10 +204,18 @@ def enable_lora( model, mode = True ):
 def disable_lora( model ):
     return enable_lora( model, False )
 
-def freeze_non_lora_weights( model ):
+def freeze_non_lora_weights( model, embeddings = False ):
+    frozen_params = []
+
     for name, param in model.named_parameters():
-        param.requires_grad_('lora_' in name)
-    return model
+        should = 'lora_' in name or (embeddings and "_emb" in name)
+
+        param.requires_grad_(should)
+
+        if not should:
+            frozen_params.append( param )
+
+    return frozen_params
 
 def lora_get_state_dict( state_dict, split = True ):
     lora = { name: param for name, param in state_dict.items() if "lora_" in name }
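A small usage sketch of the updated helper, with the embeddings switch enabled; the toy module below is made up purely for illustration:

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)         # would be wrapped by apply_lora in practice
            self.lora_down = nn.Linear(4, 1)    # stands in for an injected LoRA weight
            self.text_emb = nn.Embedding(8, 4)  # matches the "_emb" naming convention

    model = Toy()
    frozen = freeze_non_lora_weights(model, embeddings=True)
    # proj.* end up frozen and are returned in `frozen`; lora_down.* stay trainable;
    # text_emb.weight also stays trainable because embeddings=True whitelists "_emb" names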