From b6a213bbbd2edb058e003cf9d1da1986b909d1b9 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 29 Apr 2023 00:46:36 +0000 Subject: [PATCH] removed some CPU fallback wrappers because directml seems to work now without them --- tortoise/utils/device.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index df7cbb5..c082c13 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -81,10 +81,7 @@ def get_device_vram( name=get_device_name() ): return available / (1024 ** 3) -def get_device_batch_size(name=None): - if not name: - name = get_device_name() - +def get_device_batch_size(name=get_device_name()): vram = get_device_vram(name) if vram > 14: @@ -110,6 +107,8 @@ def get_device_count(name=get_device_name()): return 1 +# if you're getting errors make sure you've updated your torch-directml, and if you're still getting errors then you can uncomment the below block +""" if has_dml(): _cumsum = torch.cumsum _repeat_interleave = torch.repeat_interleave @@ -127,4 +126,5 @@ if has_dml(): torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) ) - torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) ) \ No newline at end of file + torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) ) +""" \ No newline at end of file