diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index eddfcea..b4d51ea 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -128,6 +128,9 @@ def get_device_count(name=get_device_name()): if name == "dml": import torch_directml return torch_directml.device_count() + if name == "xpu": + import intel_extension_for_pytorch + return torch.xpu.device_count() return 1