Also throw an exception for a zeroed-out tensor during training (I am very paranoid now)

This commit is contained in:
mrq 2025-02-22 14:09:41 -06:00
parent ab0abd2b12
commit 15b3c20e19

View File

@ -815,9 +815,15 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor:
def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor:
artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16)
codes = artifact["codes"]
if validate and np.count_nonzero(codes) == 0:
raise Exception(f"Artifact contains zero'd tensor: {path}")
codes = torch.from_numpy(codes.astype(int)).to(torch.int16)
# artifact was saved as a batch
if codes.dim() == 3:
codes = codes[0]