mrq 2025-02-28 18:53:07 -06:00
parent 4e7d885542
commit b97faa8173
4 changed files with 18 additions and 9 deletions

View File

@@ -751,10 +751,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 
 	metadata = {}
 	if cfg.dataset.use_metadata and metadata_path.exists():
-		try:
-			metadata = json_read( metadata_path )
-		except Exception as e:
-			return {}
+		metadata = json_read( metadata_path )
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
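Note on this hunk: dropping the try/except changes failure behavior, since a corrupt or unreadable metadata JSON now raises instead of silently yielding an empty dict (and then falling through to the non-metadata path listing). A minimal before/after sketch, assuming json_read is a thin json.load wrapper as its call here suggests; the metadata file is made up and written on the spot so the snippet runs:

    import json, tempfile
    from pathlib import Path

    def json_read(path):  # assumed helper: a plain json.load wrapper
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # hypothetical metadata file, created only so the sketch runs end to end
    metadata_path = Path(tempfile.gettempdir()) / "group.json"
    metadata_path.write_text('{"utt-0001": {"duration": 3.2}}')

    # old behavior: swallow any read/parse error and silently fall back to {}
    try:
        metadata = json_read(metadata_path)
    except Exception:
        metadata = {}

    # new behavior: a missing or corrupt file raises instead of hiding the problem
    metadata = json_read(metadata_path)
    print(len(metadata))  # 1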
@@ -820,7 +817,7 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
 	if isinstance(path, str):
 		path = Path(path)
 
-	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
+	return [ str(p) for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
 
 def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor:
 	artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
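The second hunk stringifies each iterdir() entry, so the helper now hands back plain strings rather than Path objects, which keeps downstream string handling (concatenation, dict keys, serialization) type-consistent. A small illustration, with a made-up directory and extension:

    import tempfile
    from pathlib import Path

    # hypothetical dataset directory with one file, created so the sketch runs
    path = Path(tempfile.mkdtemp())
    (path / "utt-0001.enc").touch()

    if path.exists() and path.is_dir():
        as_paths = [ p for p in path.iterdir() ]       # old: Path objects
        as_strs  = [ str(p) for p in path.iterdir() ]  # new: plain strings
        print(type(as_paths[0]), type(as_strs[0]))
        # e.g. "prefix/" + as_strs[0] works, while "prefix/" + as_paths[0] raises TypeError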

View File

@@ -586,7 +586,7 @@ class Engines(dict[str, Engine]):
 			loss_scale = 1
 			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
 				loss_scale = engine.optimizer.loss_scale
-			elif engine.loss_scaler is not None:
+			elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
 				loss_scale = engine.loss_scaler.get_scale()
 
 			if grad_norm is not None:
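The added hasattr guard matters for engine backends that never define a loss_scaler attribute at all (as opposed to defining it as None); without it, the elif raises AttributeError. A minimal sketch of the guarded lookup, using a hypothetical stand-in engine:

    class _LocalEngine:       # hypothetical engine with no AMP/DeepSpeed scaler
        optimizer = None      # and no loss_scaler attribute at all

    engine = _LocalEngine()

    loss_scale = 1
    if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
        loss_scale = engine.optimizer.loss_scale
    elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
        loss_scale = engine.loss_scaler.get_scale()
    # the previous `elif engine.loss_scaler is not None:` would raise AttributeError here
    print(loss_scale)  # 1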

View File

@@ -7,7 +7,7 @@ from typing import Literal, overload, Optional, Tuple, Union, List
 from torch import Tensor, nn
 
 # lazy
-from transformers.models.llama.configuration_llama import LlamaConfig as Config
+from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig
 from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
 from transformers.modeling_utils import PreTrainedModel
@@ -19,6 +19,18 @@ from transformers.activations import ACT2FN
 
 from .attention import *
 
+class Config(BaseConfig):
+	def __init__(
+		self,
+		attn_mode = "sdpa",
+		output_norm = True,
+		*args, **kwargs
+	):
+		super().__init__(*args, **kwargs)
+
+		self.attn_mode = attn_mode
+		self.output_norm = output_norm
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 	batch, num_key_value_heads, slen, head_dim = hidden_states.shape
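Together with the import rename above (LlamaConfig is now pulled in as BaseConfig), this hunk adds a local Config subclass so the two custom knobs ride inside the Hugging Face config object with explicit defaults. A usage sketch, with the class restated and placeholder model dimensions so it stands alone:

    from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig

    class Config(BaseConfig):  # same subclass as the hunk above, restated for completeness
        def __init__(self, attn_mode="sdpa", output_norm=True, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.attn_mode = attn_mode
            self.output_norm = output_norm

    # hypothetical instantiation; the real call site passes the model's own dimensions
    cfg = Config(
        hidden_size=1024,
        num_hidden_layers=12,
        num_attention_heads=16,
        attn_mode="flash_attention_2",  # custom field carried by the subclass
        output_norm=False,              # custom field carried by the subclass
    )
    assert cfg.attn_mode == "flash_attention_2" and cfg.output_norm is False

Subclassing gives the two fields stable names and defaults ("sdpa", True) rather than relying on loose keyword arguments attached to the base config.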

View File

@@ -450,9 +450,9 @@ class Base_V2(nn.Module):
 			is_encoder_decoder=False,
 			is_decoder=True,
 			#gradient_checkpointing=self.gradient_checkpointing,
+			output_norm = not per_level_normalization, # moves the LN out to the decoder
+			attn_mode = attention_backend,
 		)
-		self.model_config.output_norm = not per_level_normalization # moves the LN out to the decoder
-		self.model_config.attn_mode = attention_backend
 		self.model = LlamaModel(self.model_config)
 
 		if self.gradient_checkpointing and not self.model.gradient_checkpointing:
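This hunk passes output_norm and attn_mode to the Config constructor instead of assigning them after construction, so the values are already set by the time self.model = LlamaModel(self.model_config) reads the config. A before/after sketch, with placeholder values standing in for the surrounding Base_V2 attributes:

    from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig

    class Config(BaseConfig):  # restated from the previous file so the sketch is self-contained
        def __init__(self, attn_mode="sdpa", output_norm=True, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.attn_mode = attn_mode
            self.output_norm = output_norm

    per_level_normalization = False   # placeholder for the surrounding attribute
    attention_backend = "sdpa"        # placeholder for the surrounding attribute

    # old: flags bolted on after construction; anything evaluated inside
    # Config.__init__ would have seen the defaults instead
    old_config = Config(is_decoder=True, is_encoder_decoder=False)
    old_config.output_norm = not per_level_normalization
    old_config.attn_mode = attention_backend

    # new: the flags are part of the constructor call, so the config is complete
    # before it is ever handed to the model
    new_config = Config(
        is_decoder=True,
        is_encoder_decoder=False,
        output_norm=not per_level_normalization,  # moves the final LN out to the decoder
        attn_mode=attention_backend,
    )
    assert new_config.output_norm and new_config.attn_mode == "sdpa"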