diff --git a/vall_e/data.py b/vall_e/data.py index 81424fd..40e15db 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -815,9 +815,15 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else [] -def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor: +def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor: artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()] - codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16) + codes = artifact["codes"] + + if validate and np.count_nonzero(codes) == 0: + raise Exception(f"Artifact contains zero'd tensor: {path}") + + codes = torch.from_numpy(codes.astype(int)).to(torch.int16) + # artifact was saved as a batch if codes.dim() == 3: codes = codes[0]