diff --git a/vall_e/data.py b/vall_e/data.py
index 2450699..f4f1787 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -751,10 +751,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 
 	metadata = {}
 	if cfg.dataset.use_metadata and metadata_path.exists():
-		try:
-			metadata = json_read( metadata_path )
-		except Exception as e:
-			return {}
+		metadata = json_read( metadata_path )
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
@@ -820,7 +817,7 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
 	if isinstance(path, str):
 		path = Path(path)
 
-	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
+	return [ str(p) for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
 
 def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor:
 	artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 5b312ee..0f3f8c8 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -586,7 +586,7 @@ class Engines(dict[str, Engine]):
 				loss_scale = 1
 				if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
 					loss_scale = engine.optimizer.loss_scale
-				elif engine.loss_scaler is not None:
+				elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
 					loss_scale = engine.loss_scaler.get_scale()
 
 				if grad_norm is not None:
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 5f26df8..a8f577d 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -7,7 +7,7 @@ from typing import Literal, overload, Optional, Tuple, Union, List
 from torch import Tensor, nn
 
 # lazy
-from transformers.models.llama.configuration_llama import LlamaConfig as Config
+from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig
 from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
 from transformers.modeling_utils import PreTrainedModel
 
@@ -19,6 +19,18 @@ from transformers.activations import ACT2FN
 
 from .attention import *
 
+class Config(BaseConfig):
+	def __init__(
+		self,
+		attn_mode = "sdpa",
+		output_norm = True,
+		*args, **kwargs
+	):
+		super().__init__(*args, **kwargs)
+
+		self.attn_mode = attn_mode
+		self.output_norm = output_norm
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 	batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 9b2ae90..ca7be8e 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -450,9 +450,9 @@ class Base_V2(nn.Module):
 			is_encoder_decoder=False,
 			is_decoder=True,
 			#gradient_checkpointing=self.gradient_checkpointing,
+			output_norm = not per_level_normalization, # moves the LN out to the decoder
+			attn_mode = attention_backend,
 		)
-		self.model_config.output_norm = not per_level_normalization # moves the LN out to the decoder
-		self.model_config.attn_mode = attention_backend
 		self.model = LlamaModel(self.model_config)
 
 		if self.gradient_checkpointing and not self.model.gradient_checkpointing:
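
Note (not part of the patch): a minimal standalone sketch of the config pattern the llama.py hunk introduces, restated outside the repo so it can be run against transformers alone. The extra knobs (attn_mode, output_norm) are carried through the Config constructor rather than assigned onto the config object after construction, which is what the base_v2.py hunk switches to. The hyperparameter values below are illustrative placeholders, not the model's real settings.

from transformers.models.llama.configuration_llama import LlamaConfig as BaseConfig

class Config(BaseConfig):
	def __init__(self, attn_mode="sdpa", output_norm=True, *args, **kwargs):
		super().__init__(*args, **kwargs)
		# custom fields consumed by the model's attention and final-norm code paths
		self.attn_mode = attn_mode
		self.output_norm = output_norm

# construction mirrors the base_v2.py hunk: the knobs ride on the config object
# instead of being patched onto it after the fact
config = Config(
	hidden_size=1024,      # illustrative sizes only
	num_hidden_layers=12,
	num_attention_heads=16,
	attn_mode="sdpa",      # attention backend selection
	output_norm=False,     # normalization handled per level outside the decoder
)
assert config.attn_mode == "sdpa" and config.output_norm is False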