some weird fixes for an equally weird regression with LoRA loading

mrq 2024-07-22 20:47:24 -05:00
parent e33c4b0cb1
commit 188d116222
4 changed files with 35 additions and 53 deletions

vall_e/__main__.py

@@ -4,12 +4,14 @@ from .inference import TTS
 from .config import cfg
 
 def path_list(arg):
+    if not arg:
+        return None
     return [Path(p) for p in arg.split(";")]
 
 def main():
     parser = argparse.ArgumentParser("VALL-E TTS")
     parser.add_argument("text")
-    parser.add_argument("references", type=path_list)
+    parser.add_argument("references", type=path_list, default=None)
     parser.add_argument("--language", type=str, default="en")
     parser.add_argument("--out-path", type=Path, default=None)

vall_e/engines/__init__.py

@@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -37,16 +37,21 @@ def load_engines(training=True):
         lora = None
         inferencing = cfg.mode == "inferencing" or not model.config.training or not training
-        loads_state_dict = cfg.trainer.load_state_dict or inferencing
+        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+        loads_state_dict = cfg.trainer.load_state_dict # or inferencing
 
         checkpoint_path = cfg.ckpt_dir / name / "latest"
         # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
         load_path = cfg.ckpt_dir / name / "fp32.pth"
 
         # actually use the lora-specific checkpoint if available
         if cfg.lora is not None:
-            lora = cfg.lora
-            checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+            checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+            # to handle the issue of training with deepspeed, but inferencing with local
+            if checkpoint_path.exists() and backend == "local":
+                tag = open(checkpoint_path).read()
+                checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
 
         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
             print("Checkpoint missing, but weights found.")
@@ -67,12 +72,11 @@ def load_engines(training=True):
         optimizer = None
         lr_scheduler = None
 
-        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
         dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
         amp = cfg.inference.amp if inferencing else cfg.trainer.amp
         ddp = cfg.trainer.ddp
 
-        engine_class = LocalEngine if backend == "local" or inferencing else Engine
+        engine_class = LocalEngine if backend == "local" else Engine
 
         # apply model replacers
         if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -152,7 +156,7 @@ def load_engines(training=True):
             stats = state["stats"]
 
             # do not load stats if we're training a LoRA
-            if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
+            if cfg.lora is not None or cfg.trainer.restart_step_count:
                 stats = None
 
             if "module" in state:
@@ -182,13 +186,15 @@ def load_engines(training=True):
             model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
             # load lora weights if exists
             if cfg.lora is not None:
-                lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+                lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
                 if lora_path.exists():
-                    state = torch.load(lora_path, map_location=torch.device(cfg.device))
-                    state = state['lora' if 'lora' in state else 'module']
-                    model.load_state_dict(state, strict=False)
+                    print( "Loaded LoRA state dict:", lora_path )
+
+                    state = torch.load(lora_path, map_location=torch.device(cfg.device))
+                    state = state['lora' if 'lora' in state else 'module']
+                    lora_load_state_dict( model, state )
 
         # wrap if DDP is requested
         if ddp:
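lora_load_state_dict lives in vall_e/models/lora.py and its body isn't shown in this diff; the visible intent is to replace the bare non-strict model.load_state_dict(state, strict=False), which silently drops any keys that don't match the parametrized module names, and that mismatch is a plausible culprit for the loading regression. A guess at the shape of such a loader, assuming LoRA keys contain "lora_":

def lora_load_state_dict_sketch(model, state_dict):
    # keep only the low-rank tensors; everything else belongs to the frozen base model
    lora_state = {k: v for k, v in state_dict.items() if "lora_" in k}
    missing, unexpected = model.load_state_dict(lora_state, strict=False)
    # surface mismatches instead of silently ignoring them
    if unexpected:
        raise ValueError(f"unexpected LoRA keys: {unexpected}")
    return model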
@@ -213,42 +219,12 @@ def load_engines(training=True):
     engines = Engines(engines)
     engines.setup()
 
+    # this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
     if not cfg.trainer.load_state_dict:
-        engines.load_checkpoint()
+        engines.load_checkpoint(training=not inferencing)
 
     # freeze requested params
     for name, engine in engines.items():
         engine.freeze(freeze_all=False)
 
-        """
-        # copy embeddings if requested
-        if cfg.model._embeddings is not None:
-            embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-            if embeddings_path.exists():
-                embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-                if "module" in embeddings:
-                    embeddings = embeddings["module"]
-
-                frozen_params = set()
-
-                for k in list(embeddings.keys()):
-                    if re.findall(r'_emb\.', k):
-                        frozen_params.add(k)
-                    else:
-                        del embeddings[k]
-
-                engine.module.load_state_dict(embeddings, strict=False)
-
-                # there's definitely a much better way but I can't be assed at the moment
-                for name, param in engine.module.named_parameters():
-                    if name not in frozen_params:
-                        continue
-
-                    param.requires_grad_(False)
-                    engine._frozen_params.add(param)
-        """
-
-    #do_gc()
     return engines
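The new comment is worth taking seriously: Engines.load_checkpoint iterates every engine, and Engine.load_checkpoint simply returns early when a tag or state file is missing, so one engine can restore while a sibling silently keeps fresh weights. A runnable toy illustrating that fallthrough (the engine names are hypothetical):

from pathlib import Path
import tempfile

def checkpoint_status(engine_names, ckpt_dir: Path) -> dict:
    # mirrors the loop shape: absence is skipped silently, not raised
    return {name: (ckpt_dir / name / "latest").exists() for name in engine_names}

with tempfile.TemporaryDirectory() as d:
    ckpt_dir = Path(d)
    (ckpt_dir / "ar+nar").mkdir()
    (ckpt_dir / "ar+nar" / "latest").touch()
    print(checkpoint_status(["ar+nar", "nar-len"], ckpt_dir))
    # {'ar+nar': True, 'nar-len': False} -- no error for the missing one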

vall_e/engines/base.py

@@ -164,15 +164,19 @@ class Engine():
         if tag is None:
             tag_path = load_dir / "latest"
 
             if not tag_path.exists():
                 return
 
             tag = open(tag_path).read()
 
         load_path = load_dir / tag / "state.pth"
 
         if not load_path.exists():
             return
 
         state = torch.load(load_path, map_location=torch.device(cfg.device))
         self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
         self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
         self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -393,18 +397,19 @@ class Engines(dict[str, Engine]):
                     p.unlink()
                 d.rmdir()
 
-    def load_checkpoint(self, tag=None):
+    def load_checkpoint(self, tag=None, training=True):
         if not tag:
             tag = cfg.trainer.load_tag
 
         for name, engine in self.items():
             load_dir = cfg.ckpt_dir / name
             engine.load_checkpoint(
                 tag=tag,
                 load_dir=load_dir,
                 load_module_strict=cfg.trainer.strict_loading,
-                load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-                load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+                load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+                load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
                 load_module_only=cfg.trainer.load_module_only,
             )
 
         if cfg.trainer.restart_step_count:
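The doubled-up conditional is dense; pulled into a standalone predicate for illustration (hypothetical helper name), its truth table shows what the new training flag buys: inference never tries to restore optimizer or LR-scheduler state, which a local backend loading a DeepSpeed-trained LoRA has no use for.

def should_load_states(load_states: bool, load_module_only: bool, training: bool) -> bool:
    # mirrors: False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states
    return False if load_module_only or not training else load_states

assert should_load_states(True, False, training=True) is True    # normal training resume
assert should_load_states(True, False, training=False) is False  # inference skips optimizer state
assert should_load_states(True, True, training=True) is False    # module-only load wins either way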

vall_e/models/lora.py

@@ -118,7 +118,6 @@ class ParameterizedLoRA(nn.Module):
         nn.init.zeros_( self.lora_B )
 
     def forward(self, x: torch.Tensor):
-        print( self.enabled, x.shape )
         if self.enabled:
             return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
         return x
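A note on what forward receives here: under torch.nn.utils.parametrize, the input is the weight tensor being parametrized, not an activation, so the low-rank product lora_B @ lora_A is reshaped onto the weight and scaled. A self-contained sketch of the same arithmetic with made-up dimensions:

import torch

d_out, d_in, rank, alpha = 8, 16, 4, 8
scaling = alpha / rank

weight = torch.randn(d_out, d_in)        # frozen base weight
lora_A = torch.randn(rank, d_in) * 0.01  # trainable down-projection
lora_B = torch.zeros(d_out, rank)        # trainable up-projection, zero-init like nn.init.zeros_ above

# the parametrization returns the adapted weight: W + (B @ A) * scaling
adapted = weight + torch.matmul(lora_B, lora_A).view(weight.shape) * scaling
assert torch.equal(adapted, weight)  # zero-init B means no delta until training moves it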
@@ -193,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False ):
     else:
         setattr( model.get_submodule(name), k, replacement )
 
-    return model
+    return enable_lora( model )
 
 def enable_lora( model, mode = True ):
     for name, module in model.named_modules():
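With apply_lora now returning enable_lora( model ), freshly injected adapters come back switched on; enable_lora(model, mode=False) flips the same enabled flag the forward above checks. A toy sketch of the toggle pattern, independent of the repo's classes:

import torch.nn as nn

class ToggleableAdapter(nn.Module):
    def __init__(self):
        super().__init__()
        self.enabled = True

def enable_lora(model: nn.Module, mode: bool = True) -> nn.Module:
    # walk every submodule and flip the flag on anything adapter-shaped
    for name, module in model.named_modules():
        if hasattr(module, "enabled"):
            module.enabled = mode
    return model

model = nn.Sequential(nn.Linear(4, 4))
model.add_module("adapter", ToggleableAdapter())
enable_lora(model, mode=False)
assert model.adapter.enabled is False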