fixes...
parent 4e7d885542
commit b97faa8173
@@ -751,10 +751,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	metadata = {}

 	if cfg.dataset.use_metadata and metadata_path.exists():
-		try:
-			metadata = json_read( metadata_path )
-		except Exception as e:
-			return {}
+		metadata = json_read( metadata_path )

 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
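For reference, a minimal standalone sketch of the load-or-fall-back flow this hunk touches; json_read and the _fn directory fallback are replaced here by hypothetical stand-ins, and cfg is reduced to plain arguments. Without a try/except around the read, a malformed metadata file raises rather than silently yielding {}.

    import json
    from pathlib import Path

    def json_read(path: Path) -> dict:
        # stand-in for the repo's json_read helper: read and parse a JSON file
        return json.loads(path.read_text())

    def load_paths(metadata_path: Path, data_dir: Path, use_metadata: bool = True) -> dict:
        metadata = {}

        if use_metadata and metadata_path.exists():
            # no try/except: a malformed file raises instead of returning {}
            metadata = json_read(metadata_path)

        if len(metadata) == 0:
            # hypothetical fallback standing in for _fn( data_dir, ... )
            return {p.name: p for p in data_dir.iterdir()} if data_dir.is_dir() else {}

        return metadata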
@@ -820,7 +817,7 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
 	if isinstance(path, str):
 		path = Path(path)

-	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
+	return [ str(p) for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []

 def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor:
 	artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
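For reference, a standalone sketch of the listing behaviour after this hunk, assuming only what the hunk shows: each directory entry is coerced to str, so callers receive plain path strings rather than pathlib.Path objects.

    from pathlib import Path

    def list_dir_as_strings(path) -> list[str]:
        # accept either a str or a Path
        if isinstance(path, str):
            path = Path(path)
        # coerce each entry to str, mirroring the patched return statement
        return [ str(p) for p in path.iterdir() ] if path.exists() and path.is_dir() else []

    # usage: list_dir_as_strings("./data")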
@@ -586,7 +586,7 @@ class Engines(dict[str, Engine]):
 			loss_scale = 1
 			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
 				loss_scale = engine.optimizer.loss_scale
-			elif engine.loss_scaler is not None:
+			elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
 				loss_scale = engine.loss_scaler.get_scale()

 			if grad_norm is not None:
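For reference, a minimal sketch of the hardened loss-scale lookup, assuming an engine-like object on which the loss_scaler attribute may simply not exist.

    def get_loss_scale(engine) -> float:
        # default when no scaler is attached (e.g. full-precision training)
        loss_scale = 1
        if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
            loss_scale = engine.optimizer.loss_scale
        elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
            # the hasattr guard keeps engines without a loss_scaler attribute
            # from raising AttributeError
            loss_scale = engine.loss_scaler.get_scale()
        return loss_scale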
@@ -7,7 +7,7 @@ from typing import Literal, overload, Optional, Tuple, Union, List
 from torch import Tensor, nn

 # lazy
-from transformers.models.llama.configuration_llama import LlamaConfig as Config
+from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig
 from transformers.models.llama.modeling_llama import LlamaPreTrainedModel

 from transformers.modeling_utils import PreTrainedModel
@@ -19,6 +19,18 @@ from transformers.activations import ACT2FN

 from .attention import *

+class Config(BaseConfig):
+	def __init__(
+		self,
+		attn_mode = "sdpa",
+		output_norm = True,
+		*args, **kwargs
+	):
+		super().__init__(*args, **kwargs)
+
+		self.attn_mode = attn_mode
+		self.output_norm = output_norm
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 	batch, num_key_value_heads, slen, head_dim = hidden_states.shape

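For reference, a self-contained sketch of how the Config subclass added here can be constructed; hidden_size and num_hidden_layers are ordinary LlamaConfig keyword arguments, and the two new fields are stored as plain attributes for the surrounding model code to read.

    from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig

    class Config(BaseConfig):
        def __init__(self, attn_mode="sdpa", output_norm=True, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # extra knobs: attention backend selection, and whether the model
            # keeps its final layer norm (presumably consumed elsewhere in the repo)
            self.attn_mode = attn_mode
            self.output_norm = output_norm

    # usage sketch
    cfg = Config(output_norm=False, hidden_size=1024, num_hidden_layers=4)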
@@ -450,9 +450,9 @@ class Base_V2(nn.Module):
 			is_encoder_decoder=False,
 			is_decoder=True,
 			#gradient_checkpointing=self.gradient_checkpointing,
+			output_norm = not per_level_normalization, # moves the LN out to the decoder
+			attn_mode = attention_backend,
 		)
-		self.model_config.output_norm = not per_level_normalization # moves the LN out to the decoder
-		self.model_config.attn_mode = attention_backend
 		self.model = LlamaModel(self.model_config)

 		if self.gradient_checkpointing and not self.model.gradient_checkpointing:
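For reference, a hedged sketch of the construction order this hunk settles on: the options are passed to the Config constructor rather than patched onto the instance afterwards, so they are already set when the model is built from the config. Config and LlamaModel here stand in for the repo's own classes.

    config = Config(
        attn_mode="sdpa",            # previously assigned after construction
        output_norm=False,           # previously assigned after construction
        hidden_size=1024,
        num_hidden_layers=4,
    )
    model = LlamaModel(config)       # sees the final attribute values at build time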