Implement correct XPU device count

Forgot to do that
This commit is contained in:
a-One-Fan 2023-05-04 21:14:07 +03:00
parent 44d2dcbb19
commit 8618922a33

View File

@ -128,6 +128,9 @@ def get_device_count(name=get_device_name()):
if name == "dml":
import torch_directml
return torch_directml.device_count()
if name == "xpu":
import intel_extension_for_pytorch
return torch.xpu.device_count()
return 1