ugh
This commit is contained in:
parent
32287710a2
commit
685f4faec0
|
@ -58,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
_logger.warning("Checkpoint missing, but weights found:", load_path)
|
||||
_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
|
||||
loads_state_dict = True
|
||||
|
||||
# load state early
|
||||
|
@ -204,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
if cfg.lora is not None:
|
||||
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
if lora_path.exists():
|
||||
_logger.info( "Loaded LoRA state dict:", lora_path )
|
||||
_logger.info( f"Loaded LoRA state dict: {lora_path}" )
|
||||
|
||||
state = torch_load(lora_path, device=cfg.device)
|
||||
state = state['lora' if 'lora' in state else 'module']
|
||||
|
|
|
@ -102,7 +102,6 @@ try:
|
|||
has_flash_attn = True
|
||||
has_flash_attn_with_paged = True
|
||||
except Exception as e:
|
||||
raise e
|
||||
_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
|
||||
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue
Block a user