diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index df7cbb5..c082c13 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -81,10 +81,7 @@ def get_device_vram( name=get_device_name() ):
 
     return available / (1024 ** 3)
 
-def get_device_batch_size(name=None):
-    if not name:
-        name = get_device_name()
-        
+def get_device_batch_size(name=get_device_name()):
     vram = get_device_vram(name)
 
     if vram > 14:
@@ -110,6 +107,8 @@ def get_device_count(name=get_device_name()):
     return 1
 
 
+# If you're getting errors, make sure you've updated torch-directml; if you're still getting errors, you can uncomment the block below.
+"""
 if has_dml():
     _cumsum = torch.cumsum
     _repeat_interleave = torch.repeat_interleave
@@ -127,4 +126,5 @@ if has_dml():
     torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
     torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
     torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
-    torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
\ No newline at end of file
+    torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
+"""
\ No newline at end of file
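
For context (not part of the patch): the commented-out block monkeypatches torch.cumsum, torch.repeat_interleave, torch.multinomial and their Tensor-method counterparts so each op runs on the CPU and the result is moved back to the original (DirectML) device. Below is a minimal, self-contained sketch of that CPU-fallback pattern for a single op, assuming plain PyTorch with no torch-directml installed; the wrapper name cumsum_cpu_fallback is illustrative only.

    import torch

    # Keep a reference to the original op so the wrapper can still call it.
    _cumsum = torch.cumsum

    def cumsum_cpu_fallback(tensor, *args, **kwargs):
        # Move the input to the CPU, run the original op there, then move the
        # result back to whatever device the input tensor lives on.
        return _cumsum(tensor.to("cpu"), *args, **kwargs).to(tensor.device)

    # Example usage (on a CPU-only machine this is just a no-op round trip):
    x = torch.arange(5)
    print(cumsum_cpu_fallback(x, dim=0))  # tensor([ 0,  1,  3,  6, 10])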