also throw exception for zero'd out tensor during training (I am very paranoid now)
This commit is contained in:
parent
ab0abd2b12
commit
15b3c20e19
|
@ -815,9 +815,15 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
|
|||
|
||||
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor:
|
||||
def _load_artifact(path, return_metadata=False, return_artifact=False, validate=True) -> Tensor:
|
||||
artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
|
||||
codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16)
|
||||
codes = artifact["codes"]
|
||||
|
||||
if validate and np.count_nonzero(codes) == 0:
|
||||
raise Exception(f"Artifact contains zero'd tensor: {path}")
|
||||
|
||||
codes = torch.from_numpy(codes.astype(int)).to(torch.int16)
|
||||
|
||||
# artifact was saved as a batch
|
||||
if codes.dim() == 3:
|
||||
codes = codes[0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user