diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 5d1175f..4088bdc 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -8,7 +8,7 @@ if cfg.trainer.backend == "deepspeed":
 elif cfg.trainer.backend == "local":
 	from .base import Engine
 
-from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models
 from ..utils import wrapper as ml
@@ -40,7 +40,7 @@ def load_engines(training=True):
 		loads_state_dict = cfg.trainer.load_state_dict or inferencing
 		ddp = cfg.trainer.ddp
 
-		engine_class = _Engine if backend == "local" or inferencing else Engine
+		engine_class = LocalEngine if backend == "local" or inferencing else Engine
 
 		if inferencing:
 			model.config.training = False
@@ -101,9 +101,9 @@ def load_engines(training=True):
 				warmup_steps = cfg.hyperparameters.warmup_steps
 			)
 
-
-
+		"""
 		# set up our LR scheduler here
+		"""
 
 		if inferencing:
 			optimizer = None
@@ -113,7 +113,7 @@ def load_engines(training=True):
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 
 		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
-			print("DeepSpeed checkpoint missing, but weights found.")
+			print("Checkpoint missing, but weights found.")
			loads_state_dict = True
 
 		stats = None
@@ -161,7 +161,7 @@ def load_engines(training=True):
 
 		# deepspeed inferencing
 		elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
-			engine_class = _Engine
+			engine_class = LocalEngine
 			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
 
 		# use base engine if requested
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index f55a642..6c4437a 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -79,7 +79,7 @@ class Engine():
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
 		# freeze non-LoRA params if requested
-		if not self.hyper_config.frozen_params and not freeze_all:
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
 			return freeze_non_lora_weights( self.module )
 
 		for name, param in self.module.named_parameters():
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index afa7c87..595f46c 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -66,6 +66,10 @@ class Engine(DeepSpeedEngine):
 		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
+		# freeze non-LoRA params if requested
+		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+			return freeze_non_lora_weights( self.module )
+
 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
 				param.requires_grad_(False)
@@ -108,8 +112,31 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			print(str(e))
 
-	def load_loras(self):
-		...
+	# we'll just have to live with the LoRA weights living within our main weights
+	# they're easy to extract anyways
+	def load_checkpoint(self, load_dir, **kwargs ):
+		# override to load the lora instead
+		if cfg.lora is not None:
+			load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+		return super().load_checkpoint( load_dir, **kwargs )
+
+	def save_checkpoint(self, save_dir, **kwargs ):
+		# override to save the lora instead
+		if cfg.lora is not None:
+			save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+		return super().save_checkpoint( save_dir, **kwargs )
+
+	def load_loras( self ):
+		# apply lora weights
+		for lora in cfg.loras:
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
+
+			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+			if lora_path.exists():
+				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+				self.module = lora_load_state_dict( self.module, state_dict )
 
 	def traverse(self, *args, **kwargs):
 		with ml.autocast():
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index d5af171..947d0e6 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -28,7 +28,7 @@ from .distributed import (
 	local_leader_only,
 )
 
-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
 from .utils import to_device, do_gc, truncate_json
 
 from ..utils import wrapper as ml