diff --git a/tortoise_tts/engines/__init__.py b/tortoise_tts/engines/__init__.py index 38b4c28..75a26c7 100755 --- a/tortoise_tts/engines/__init__.py +++ b/tortoise_tts/engines/__init__.py @@ -166,6 +166,14 @@ def load_engines(training=True): model.load_state_dict(state, strict=cfg.trainer.strict_loading) + # load lora weights if exists + if cfg.lora is not None: + lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth" + if lora_path.exists(): + state = torch.load(lora_path, map_location=torch.device(cfg.device)) + state = state['lora' if 'lora' in state else 'module'] + model.load_state_dict(state, strict=False) + # wrap if DDP is requested if ddp: model = ddp_model(model)