some weird fixes for an equally weird regression with LoRA loading
This commit is contained in:
parent e33c4b0cb1
commit 188d116222
@@ -4,12 +4,14 @@ from .inference import TTS
 from .config import cfg
 
 def path_list(arg):
+	if not arg:
+		return None
 	return [Path(p) for p in arg.split(";")]
 
 def main():
	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
-	parser.add_argument("references", type=path_list)
+	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--out-path", type=Path, default=None)
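The `+` lines above make `references` tolerate being absent: `default=None` covers the omitted argument, and the new `if not arg` guard covers an empty string, which previously parsed to a bogus one-element list. A minimal sketch of the resulting behavior (the examples are illustrative, not from the repo):

	path_list("")              # -> None, instead of a junk single-entry list
	path_list("a.wav;b.wav")   # -> [Path("a.wav"), Path("b.wav")]
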
@@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -37,7 +37,8 @@ def load_engines(training=True):
 	lora = None
 
 	inferencing = cfg.mode == "inferencing" or not model.config.training or not training
-	loads_state_dict = cfg.trainer.load_state_dict or inferencing
+	backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+	loads_state_dict = cfg.trainer.load_state_dict # or inferencing
 
 	checkpoint_path = cfg.ckpt_dir / name / "latest"
 	# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
@@ -45,8 +46,12 @@ def load_engines(training=True):
 
 	# actually use the lora-specific checkpoint if available
 	if cfg.lora is not None:
 		lora = cfg.lora
-		checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+		checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+		# to handle the issue of training with deepspeed, but inferencing with local
+		if checkpoint_path.exists() and backend == "local":
+			tag = open(checkpoint_path).read()
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
 
 	if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 		print("Checkpoint missing, but weights found.")
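The added branch follows DeepSpeed's checkpoint convention, where `latest` is a plain-text file naming the tag subdirectory that holds the actual weights; a DeepSpeed-trained LoRA loaded on the local backend has to resolve that indirection by hand. A sketch of the layout being traversed (directory names are illustrative):

	# cfg.ckpt_dir / lora.full_name / "latest"                         <- text file, contents e.g. "global_step1000"
	# cfg.ckpt_dir / lora.full_name / "global_step1000" / "state.pth"  <- what checkpoint_path resolves to
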
@@ -67,12 +72,11 @@ def load_engines(training=True):
 	optimizer = None
 	lr_scheduler = None
 
-	backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 	dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 	amp = cfg.inference.amp if inferencing else cfg.trainer.amp
 	ddp = cfg.trainer.ddp
 
-	engine_class = LocalEngine if backend == "local" or inferencing else Engine
+	engine_class = LocalEngine if backend == "local" else Engine
 
 	# apply model replacers
 	if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -152,7 +156,7 @@ def load_engines(training=True):
 		stats = state["stats"]
 
 		# do not load stats if we're training a LoRA
-		if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
+		if cfg.lora is not None or cfg.trainer.restart_step_count:
 			stats = None
 
 		if "module" in state:
@@ -184,11 +188,13 @@ def load_engines(training=True):
 		# load lora weights if exists
 		if cfg.lora is not None:
-			lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+			lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
 			if lora_path.exists():
 				print( "Loaded LoRA state dict:", lora_path )
+
 				state = torch.load(lora_path, map_location=torch.device(cfg.device))
 				state = state['lora' if 'lora' in state else 'module']
-				model.load_state_dict(state, strict=False)
+				lora_load_state_dict( model, state )
 
 		# wrap if DDP is requested
 		if ddp:
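Swapping the blanket `model.load_state_dict(state, strict=False)` for `lora_load_state_dict` is the substantive fix here: the checkpoint's `lora`/`module` sub-dict should go through a LoRA-aware loader rather than being dumped onto the model wholesale. The diff only shows the call site; a plausible sketch of such a helper, assuming it filters for LoRA-injected keys (the `"lora_"` prefix is a guess, not the repo's actual implementation):

	def lora_load_state_dict( model, state_dict ):
		# keep only LoRA-injected weights (hypothetical key filter), then load
		# non-strictly so the base model's untouched keys are left alone
		state_dict = { k: v for k, v in state_dict.items() if "lora_" in k }
		model.load_state_dict( state_dict, strict=False )
		return model
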
@@ -213,42 +219,12 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()
 
 	# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
 	if not cfg.trainer.load_state_dict:
-		engines.load_checkpoint()
+		engines.load_checkpoint(training=not inferencing)
 
 	# freeze requested params
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
-	"""
-	# copy embeddings if requested
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb\.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			# there's definitely a much better way but I can't be assed at the moment
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
 	#do_gc()
 
 	return engines
@@ -164,15 +164,19 @@ class Engine():
 		if tag is None:
 			tag_path = load_dir / "latest"
+
+			if not tag_path.exists():
+				return
+
 			tag = open(tag_path).read()
 
 		load_path = load_dir / tag / "state.pth"
 
 		if not load_path.exists():
 			return
 
 		state = torch.load(load_path, map_location=torch.device(cfg.device))
 
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
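The new early return makes a missing `latest` file a soft no-op instead of a crash: previously a fresh run with nothing saved yet would hit `open(tag_path)` before the existing `load_path` check could help.

	# failure the guard avoids (illustrative):
	#   tag = open(load_dir / "latest").read()   # FileNotFoundError when no checkpoint exists
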
@@ -393,18 +397,19 @@ class Engines(dict[str, Engine]):
 				p.unlink()
 			d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, training=True):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
 		for name, engine in self.items():
 			load_dir = cfg.ckpt_dir / name
 
 			engine.load_checkpoint(
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-				load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+				load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
 				load_module_only=cfg.trainer.load_module_only,
 			)
 		if cfg.trainer.restart_step_count:
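Paired with the `engines.load_checkpoint(training=not inferencing)` call site earlier in the commit, the new `training` flag skips restoring optimizer and LR-scheduler state on inference runs, where that state may be absent or irrelevant. Call-site sketch:

	engines.load_checkpoint(training=False)   # inference: module weights only, no optimizer/scheduler state
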
@@ -118,7 +118,6 @@ class ParameterizedLoRA(nn.Module):
 		nn.init.zeros_( self.lora_B )
 
 	def forward(self, x: torch.Tensor):
-		print( self.enabled, x.shape )
 		if self.enabled:
 			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
 		return x
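Beyond dropping the leftover debug print, note that this forward pass is the standard low-rank update expressed as a parametrization over the weight tensor itself, so `x` here is the weight, not an activation. A shape sketch, assuming a weight of shape (out_features, in_features):

	# lora_B: (out_features, rank); dropout(lora_A): (rank, in_features)
	# matmul(lora_B, dropout(lora_A)): (out_features, in_features) == x.shape
	# enabled:  x + (B @ A).view(x.shape) * scaling
	# disabled: x unchanged
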
@@ -193,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
 	else:
 		setattr( model.get_submodule(name), k, replacement )
 
-	return model
+	return enable_lora( model )
 
 def enable_lora( model, mode = True ):
 	for name, module in model.named_modules():
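Returning `enable_lora( model )` instead of the bare model ensures freshly-applied adapters start out with `self.enabled` set, matching the check in `forward()` above. The body of `enable_lora` is cut off by the hunk; a plausible completion, assuming it simply flips the flag on every LoRA module (the isinstance filter is an assumption, not shown in the diff):

	def enable_lora( model, mode = True ):
		for name, module in model.named_modules():
			if isinstance( module, ParameterizedLoRA ):   # assumed module filter
				module.enabled = mode
		return model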