From 15b3c20e19fe7908e64c25d0909b38158f40987a Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 22 Feb 2025 14:09:41 -0600 Subject: [PATCH] also throw exception for zero'd out tensor during training (I am very paranoid now) --- vall_e/data.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 81424fd..40e15db 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -815,9 +815,15 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else [] -def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor: +def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor: artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()] - codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16) + codes = artifact["codes"] + + if validate and np.count_nonzero(codes) == 0: + raise Exception(f"Artifact contains zero'd tensor: {path}") + + codes = torch.from_numpy(codes.astype(int)).to(torch.int16) + # artifact was saved as a batch if codes.dim() == 3: codes = codes[0]