forked from mrq/tortoise-tts
Implement correct XPU device count
Forgot to do that
This commit is contained in:
parent
44d2dcbb19
commit
8618922a33
|
@ -128,6 +128,9 @@ def get_device_count(name=get_device_name()):
|
|||
if name == "dml":
|
||||
import torch_directml
|
||||
return torch_directml.device_count()
|
||||
if name == "xpu":
|
||||
import intel_extension_for_pytorch
|
||||
return torch.xpu.device_count()
|
||||
|
||||
return 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user