From 1d159b1476a8015da08ea2a155fd1e93a7d7c12a Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 17 Jun 2024 13:28:18 -0500
Subject: [PATCH] updated export routine to split LoRA weights from the state dict (should work with deepspeed)

---
 vall_e/config.py       |  4 ----
 vall_e/engines/base.py | 19 ++++++++++++-------
 vall_e/models/lora.py  |  2 +-
 3 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 990bdc4..0fa98d2 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -793,10 +793,6 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False
 
-		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
-		if self.trainer.backend == "deepspeed" and self.lora is not None:
-			raise Exception("LoRAs are currently unsupported with deepspeed backend")
-
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 6c4437a..3803098 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -331,14 +331,18 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)
 
 	def export(self, userdata={}, callback=None):
-		# to-do: lora exporting
-		if cfg.lora is not None:
-			return
-
 		for name, engine in self.items():
-			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			module = engine.module.state_dict()
+			lora = None
+			save_dir = cfg.ckpt_dir / name / "fp32.pth"
+
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+
 			state_dict = {
-				'module': engine.module.state_dict(),
+				'module': module,
+				'lora': lora,
 				"stats": {
 					"global_step": engine.global_step,
 					"micro_step": engine.micro_step,
@@ -349,7 +353,8 @@ class Engines(dict[str, Engine]):
 			}
 			if callback:
 				state_dict = callback( state_dict, engine.hyper_config )
-			torch.save(state_dict, outpath)
+
+			torch.save(state_dict, save_dir)
 			print(f"Exported {name} to {outpath}")
 
 	def save_checkpoint(self, tag=None):
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index 926c3a8..b0cf76c 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -10,7 +10,7 @@ import math
 from typing import Optional, List
 
 # to-do: set cfg to decide
-USE_PARAMETRIZATION = False
+USE_PARAMETRIZATION = True
 
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
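
For readers skimming the patch, the core idea is the split performed by lora_get_state_dict( module, split = True ). The following is a hypothetical, minimal sketch of what such a split can look like, not the helper actually defined in vall_e/models/lora.py; it assumes LoRA parameters can be recognized by a "lora_" substring in their state-dict keys, which may not match the repo's real naming or return conventions.

# Hypothetical sketch only: partition a state dict into LoRA-only weights and
# the remaining base-model weights, mirroring the shape of the
# `lora, module = lora_get_state_dict( module, split = True )` call used above.
def split_lora_state_dict(state_dict: dict, marker: str = "lora_"):
    lora = {k: v for k, v in state_dict.items() if marker in k}
    base = {k: v for k, v in state_dict.items() if marker not in k}
    return lora, base

With the export change, the saved fp32.pth then carries both a 'module' and a 'lora' key, so a loader can restore the base weights and the LoRA weights independently.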