From d7c6be6f78311bfd9e7500e26847860a3b443bb3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 30 Jul 2024 22:15:56 -0500
Subject: [PATCH] fix weird regression in handling checkpoints when backend is
 local, but deepspeed checkpoints are in (it was handled with LoRA loading but
 not real loading...)

---
 vall_e/engines/__init__.py | 8 ++++----
 vall_e/export.py           | 3 ++-
 vall_e/inference.py        | 4 +---
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index f2029db..6d4be71 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -48,10 +48,10 @@ def load_engines(training=True):
 		if cfg.lora is not None:
 			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
 
-			# to handle the issue of training with deepspeed, but inferencing with local
-			if checkpoint_path.exists() and backend == "local":
-				tag = open(checkpoint_path).read()
-				checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
+		# to handle the issue of training with deepspeed, but inferencing with local
+		if checkpoint_path.exists() and backend == "local":
+			tag = open(checkpoint_path).read()
+			checkpoint_path = checkpoint_path.parent / tag / "state.pth"
 
 		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 			print("Checkpoint missing, but weights found.")
diff --git a/vall_e/export.py b/vall_e/export.py
index f0cb66c..d5d958b 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -52,6 +52,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 
 	state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
 	state_dict['module']['lm_head.weight'] = out_proj
+	del state_dict['module']['classifier.bias']
 
 	return state_dict
 
@@ -130,7 +131,7 @@ def main():
 	# necessary to ensure we are actually exporting the weights right
 	cfg.inference.backend = cfg.trainer.backend
 
-	engines = load_engines(training=False)
+	engines = load_engines(training=False) # to ignore loading optimizer state
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
 if __name__ == "__main__":
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 669cc5d..fa169b5 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -190,9 +190,7 @@ class TTS():
 		phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
 		lang = to_device(lang, device=self.device, dtype=torch.uint8)
 
-		text_list = [ phns ]
-		proms_list = [ prom ]
-
+		# to-do: add in case for experimental.hf model
 		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 			if model_ar is not None:
 				resps_list = model_ar(
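
Note: the engines/__init__.py hunk is the substance of the fix. DeepSpeed records
the most recent checkpoint tag in a plain-text "latest" file next to the tagged
checkpoint directories, and the local backend must dereference that file itself;
before this patch that dereference only happened inside the LoRA branch, so a
base model trained with deepspeed could not be loaded locally. Below is a
minimal standalone sketch of the resolution logic, assuming this layout — the
helper name resolve_checkpoint is hypothetical; the "latest" and "state.pth"
filenames are taken from the hunk above:

    from pathlib import Path

    def resolve_checkpoint(ckpt_dir: Path, backend: str = "local") -> Path:
        """Dereference DeepSpeed's `latest` tag file when loading locally.

        DeepSpeed saves checkpoints under `<ckpt_dir>/<tag>/` and records the
        most recent tag in a plain-text `<ckpt_dir>/latest` file. The local
        backend has to follow that indirection to find `state.pth`.
        """
        checkpoint_path = ckpt_dir / "latest"
        if checkpoint_path.exists() and backend == "local":
            tag = checkpoint_path.read_text().strip()
            # parent-relative, so the same logic serves both the base model's
            # directory and a LoRA's directory (the regression fixed here)
            checkpoint_path = checkpoint_path.parent / tag / "state.pth"
        return checkpoint_path

    # hypothetical usage; the real loader lives in vall_e/engines/__init__.py:
    # state = torch.load(resolve_checkpoint(some_ckpt_dir), map_location="cpu")

Switching the resolved path to checkpoint_path.parent (rather than rebuilding it
from cfg.ckpt_dir / cfg.lora.full_name) is what lets the same four lines sit
outside the LoRA branch and work for either checkpoint directory.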