diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 2173dca..a973fe2 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -58,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
 			checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 
 		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-			_logger.warning("Checkpoint missing, but weights found:", load_path)
+			_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
 			loads_state_dict = True
 
 		# load state early
@@ -204,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
 		if cfg.lora is not None:
 			lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 			if lora_path.exists():
-				_logger.info( "Loaded LoRA state dict:", lora_path )
+				_logger.info( f"Loaded LoRA state dict: {lora_path}" )
				state = torch_load(lora_path, device=cfg.device)
 				state = state['lora' if 'lora' in state else 'module']
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 1cef053..0f2cec0 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -102,7 +102,6 @@ try:
 	has_flash_attn = True
 	has_flash_attn_with_paged = True
 except Exception as e:
-	raise e
 	_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
 
 try:
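For context on the first two hunks, a minimal standalone sketch (logger name and path are hypothetical, not from the repo) of why the old calls misbehave: Python's logging treats extra positional arguments as %-format values, so a message with no %s placeholder fails to format inside the handler and the path is never printed. The f-string form used in the diff interpolates eagerly instead.

    import logging

    logging.basicConfig(level=logging.WARNING)
    log = logging.getLogger("demo")

    load_path = "/tmp/ckpt/fp32.sd"  # hypothetical path, for illustration only

    # Old form: load_path is consumed as a %-format argument, but the
    # message has no %s placeholder, so the handler reports a logging
    # error and the path never appears in the output.
    log.warning("Checkpoint missing, but weights found:", load_path)

    # Fixed form, as in the diff: interpolate eagerly with an f-string.
    log.warning(f"Checkpoint missing, but weights found: {load_path}")

The third hunk is independent: the stray `raise e` made the `except` block re-raise before the warning could run, turning a soft capability probe for `flash_attn` into a hard failure; removing it lets the fallback path proceed as intended.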