some weird fixes for an equally weird regression with LoRA loading

commit 188d116222
parent e33c4b0cb1
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
@@ -4,12 +4,14 @@ from .inference import TTS
 from .config import cfg
 
 def path_list(arg):
+	if not arg:
+		return None
 	return [Path(p) for p in arg.split(";")]
 
 def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
-	parser.add_argument("references", type=path_list)
+	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--out-path", type=Path, default=None)
 
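Note: the new guard matters because str.split never returns an empty list, so an empty references value used to become a one-element list holding a bogus Path. A self-contained sketch of the fixed helper (the .wav names are illustrative, not from the repo):

from pathlib import Path

def path_list(arg):
	# "".split(";") yields [''] rather than [], and Path("") normalizes
	# to Path('.'), so an empty value used to produce a junk reference
	if not arg:
		return None
	return [Path(p) for p in arg.split(";")]

assert path_list("") is None
assert path_list(None) is None
assert path_list("a.wav;b.wav") == [Path("a.wav"), Path("b.wav")]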
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
@@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -37,16 +37,21 @@ def load_engines(training=True):
 		lora = None
 
 		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
-		loads_state_dict = cfg.trainer.load_state_dict or inferencing
+		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+		loads_state_dict = cfg.trainer.load_state_dict # or inferencing
 
 		checkpoint_path = cfg.ckpt_dir / name / "latest"
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 
 		# actually use the lora-specific checkpoint if available
 		if cfg.lora is not None:
-			lora = cfg.lora
-			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+		# to handle the issue of training with deepspeed, but inferencing with local
+		if checkpoint_path.exists() and backend == "local":
+			tag = open(checkpoint_path).read()
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
 
 		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 			print("Checkpoint missing, but weights found.")
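Note: the new backend == "local" branch follows DeepSpeed's checkpoint layout, where latest is a plain-text file naming the tag subdirectory that actually holds state.pth. A standalone sketch of that resolution (the helper name is mine, not the repo's):

from pathlib import Path

def resolve_local_checkpoint(ckpt_dir: Path, full_name: str):
	# DeepSpeed writes <name>/latest containing a tag string; the local
	# backend instead wants the concrete <name>/<tag>/state.pth file
	latest = ckpt_dir / full_name / "latest"
	if not latest.exists():
		return None
	tag = latest.read_text()
	return ckpt_dir / full_name / tag / "state.pth"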
@@ -67,12 +72,11 @@ def load_engines(training=True):
 		optimizer = None
 		lr_scheduler = None
 
-		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
 		ddp = cfg.trainer.ddp
 
-		engine_class = LocalEngine if backend == "local" or inferencing else Engine
+		engine_class = LocalEngine if backend == "local" else Engine
 
 		# apply model replacers
 		if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -152,7 +156,7 @@ def load_engines(training=True):
 			stats = state["stats"]
 
 			# do not load stats if we're training a LoRA
-			if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
+			if cfg.lora is not None or cfg.trainer.restart_step_count:
 				stats = None
 
 			if "module" in state:
@@ -182,13 +186,15 @@ def load_engines(training=True):
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		# load lora weights if exists
 		if cfg.lora is not None:
-			lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+			lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
 			if lora_path.exists():
-				state = torch.load(lora_path, map_location=torch.device(cfg.device))
-				state = state['lora' if 'lora' in state else 'module']
-				model.load_state_dict(state, strict=False)
+				print( "Loaded LoRA state dict:", lora_path )
+
+				state = torch.load(lora_path, map_location=torch.device(cfg.device))
+				state = state['lora' if 'lora' in state else 'module']
+				lora_load_state_dict( model, state )
 
 		# wrap if DDP is requested
 		if ddp:
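Note: lora_load_state_dict replaces the bare model.load_state_dict(state, strict=False) call here. Its body is not part of this diff; a plausible minimal version, given the lora_A/lora_B parameters visible in models/lora.py below, filters the checkpoint down to adapter weights before a non-strict load:

def lora_load_state_dict( model, state_dict ):
	# sketch only: keep just the LoRA tensors so a non-strict load
	# cannot clobber base-model weights with stale entries
	state_dict = { k: v for k, v in state_dict.items() if "lora_" in k }
	model.load_state_dict( state_dict, strict=False )
	return model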
@@ -213,42 +219,12 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()
 
+	# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
 	if not cfg.trainer.load_state_dict:
-		engines.load_checkpoint()
+		engines.load_checkpoint(training=not inferencing)
 
 	# freeze requested params
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
-	"""
-	# copy embeddings if requested
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb\.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			# there's definitely a much better way but I can't be assed at the moment
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-
-	#do_gc()
-
 	return engines
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
@@ -164,15 +164,19 @@ class Engine():
 
 		if tag is None:
 			tag_path = load_dir / "latest"
+
 			if not tag_path.exists():
 				return
+
 			tag = open(tag_path).read()
+
 		load_path = load_dir / tag / "state.pth"
+
 		if not load_path.exists():
 			return
 
 		state = torch.load(load_path, map_location=torch.device(cfg.device))
 
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -393,18 +397,19 @@ class Engines(dict[str, Engine]):
 					p.unlink()
 				d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, training=True):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
 		for name, engine in self.items():
 			load_dir = cfg.ckpt_dir / name
+
 			engine.load_checkpoint(
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-				load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+				load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
 				load_module_only=cfg.trainer.load_module_only,
 			)
 			if cfg.trainer.restart_step_count:
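Note: combined with engines.load_checkpoint(training=not inferencing) above, the new training flag means inference-time loads skip optimizer and LR-scheduler state, which a LoRA checkpoint may not even contain. How the guard resolves (the helper name is illustrative):

def wants_states(load_module_only: bool, training: bool, load_states: bool) -> bool:
	# module-only loads and inference-time loads never need optimizer
	# or scheduler state; otherwise defer to cfg.trainer.load_states
	if load_module_only or not training:
		return False
	return load_states

assert wants_states(False, True, True) is True
assert wants_states(False, False, True) is False  # inferencing
assert wants_states(True, True, True) is False    # module-only load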
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
@@ -118,7 +118,6 @@ class ParameterizedLoRA(nn.Module):
 		nn.init.zeros_( self.lora_B )
 
 	def forward(self, x: torch.Tensor):
-		print( self.enabled, x.shape )
 		if self.enabled:
 			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
 		return x
@@ -193,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
 		else:
 			setattr( model.get_submodule(name), k, replacement )
 
-	return model
+	return enable_lora( model )
 
 def enable_lora( model, mode = True ):
 	for name, module in model.named_modules():
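Note: apply_lora now returns enable_lora( model ), so a freshly wrapped model comes back with its adapters switched on. The body of enable_lora is cut off by this hunk, but from the loop header above and the self.enabled check in forward(), a minimal version would look like this (the isinstance filter is my assumption):

def enable_lora( model, mode = True ):
	# flip the runtime flag that ParameterizedLoRA.forward() checks
	# before adding the low-rank delta
	for name, module in model.named_modules():
		if not isinstance( module, ParameterizedLoRA ):
			continue
		module.enabled = mode
	return model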