forked from mrq/tortoise-tts
removed some CPU fallback wrappers because directml seems to work now without them
This commit is contained in:
parent
2f7d9ab932
commit
b6a213bbbd
|
@ -81,10 +81,7 @@ def get_device_vram( name=get_device_name() ):
|
||||||
|
|
||||||
return available / (1024 ** 3)
|
return available / (1024 ** 3)
|
||||||
|
|
||||||
def get_device_batch_size(name=None):
|
def get_device_batch_size(name=get_device_name()):
|
||||||
if not name:
|
|
||||||
name = get_device_name()
|
|
||||||
|
|
||||||
vram = get_device_vram(name)
|
vram = get_device_vram(name)
|
||||||
|
|
||||||
if vram > 14:
|
if vram > 14:
|
||||||
|
@ -110,6 +107,8 @@ def get_device_count(name=get_device_name()):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
# if you're getting errors make sure you've updated your torch-directml, and if you're still getting errors then you can uncomment the below block
|
||||||
|
"""
|
||||||
if has_dml():
|
if has_dml():
|
||||||
_cumsum = torch.cumsum
|
_cumsum = torch.cumsum
|
||||||
_repeat_interleave = torch.repeat_interleave
|
_repeat_interleave = torch.repeat_interleave
|
||||||
|
@ -127,4 +126,5 @@ if has_dml():
|
||||||
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||||
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||||
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||||
|
"""
|
Loading…
Reference in New Issue
Block a user