From 6c2e00ce2ab51c636bf23a8ff9b4cb125dc98e2c Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 18 Jun 2024 21:46:42 -0500 Subject: [PATCH] load exported LoRA weights if exists (to-do: make a better LoRA loading mechanism) --- tortoise_tts/engines/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tortoise_tts/engines/__init__.py b/tortoise_tts/engines/__init__.py index 38b4c28..75a26c7 100755 --- a/tortoise_tts/engines/__init__.py +++ b/tortoise_tts/engines/__init__.py @@ -166,6 +166,14 @@ def load_engines(training=True): model.load_state_dict(state, strict=cfg.trainer.strict_loading) + # load lora weights if exists + if cfg.lora is not None: + lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth" + if lora_path.exists(): + state = torch.load(lora_path, map_location=torch.device(cfg.device)) + state = state['lora' if 'lora' in state else 'module'] + model.load_state_dict(state, strict=False) + # wrap if DDP is requested if ddp: model = ddp_model(model)