diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index 7b16e7f..8e065b8 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -1,127 +1,128 @@
-import torch
-import psutil
-import importlib
-
-DEVICE_OVERRIDE = None
-DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
-
-from inspect import currentframe, getframeinfo
-import gc
-
-def do_gc():
-    gc.collect()
-    try:
-        torch.cuda.empty_cache()
-    except Exception as e:
-        pass
-
-def print_stats(collect=False):
-    cf = currentframe().f_back
-    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
-
-    if collect:
-        do_gc()
-
-    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
-    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
-    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
-    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
-
-
-def has_dml():
-    loader = importlib.find_loader('torch_directml')
-    if loader is None:
-        return False
-    
-    import torch_directml
-    return torch_directml.is_available()
-
-def set_device_name(name):
-    global DEVICE_OVERRIDE
-    DEVICE_OVERRIDE = name
-
-def get_device_name(attempt_gc=True):
-    global DEVICE_OVERRIDE
-    if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
-        return DEVICE_OVERRIDE
-
-    name = 'cpu'
-
-    if torch.cuda.is_available():
-        name = 'cuda'
-        if attempt_gc:
-            torch.cuda.empty_cache() # may have performance implications
-    elif has_dml():
-        name = 'dml'
-
-    return name
-
-def get_device(verbose=False):
-    name = get_device_name()
-
-    if verbose:
-        if name == 'cpu':
-            print("No hardware acceleration is available, falling back to CPU...")    
-        else:
-            print(f"Hardware acceleration found: {name}")
-
-    if name == "dml":
-        import torch_directml
-        return torch_directml.device()
-
-    return torch.device(name)
-
-def get_device_vram( name=get_device_name() ):
-    available = 1
-
-    if name == "cuda":
-        _, available = torch.cuda.mem_get_info()
-    elif name == "cpu":
-        available = psutil.virtual_memory()[4]
-
-    return available / (1024 ** 3)
-
-def get_device_batch_size(name=None):
-    vram = get_device_vram(name)
-
-    if vram > 14:
-        return 16
-    elif vram > 10:
-        return 8
-    elif vram > 7:
-        return 4
-    """
-    for k, v in DEVICE_BATCH_SIZE_MAP:
-        if vram > k:
-            return v
-    """
-    return 1
-
-def get_device_count(name=get_device_name()):
-    if name == "cuda":
-        return torch.cuda.device_count()
-    if name == "dml":
-        import torch_directml
-        return torch_directml.device_count()
-
-    return 1
-
-
-if has_dml():
-    _cumsum = torch.cumsum
-    _repeat_interleave = torch.repeat_interleave
-    _multinomial = torch.multinomial
-    
-    _Tensor_new = torch.Tensor.new
-    _Tensor_cumsum = torch.Tensor.cumsum
-    _Tensor_repeat_interleave = torch.Tensor.repeat_interleave
-    _Tensor_multinomial = torch.Tensor.multinomial
-
-    torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
-    torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
-    torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
-    
-    torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
-    torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
-    torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
+import torch
+import psutil
+import importlib
+
+DEVICE_OVERRIDE = None
+DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
+
+from inspect import currentframe, getframeinfo
+import gc
+
+def do_gc():
+    gc.collect()
+    try:
+        torch.cuda.empty_cache()
+    except Exception as e:
+        pass
+
+def print_stats(collect=False):
+    cf = currentframe().f_back
+    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
+
+    if collect:
+        do_gc()
+
+    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
+    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
+    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
+
+
+def has_dml():
+    loader = importlib.find_loader('torch_directml')
+    if loader is None:
+        return False
+    
+    import torch_directml
+    return torch_directml.is_available()
+
+def set_device_name(name):
+    global DEVICE_OVERRIDE
+    DEVICE_OVERRIDE = name
+
+def get_device_name(attempt_gc=True):
+    global DEVICE_OVERRIDE
+    if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
+        return DEVICE_OVERRIDE
+
+    name = 'cpu'
+
+    if torch.cuda.is_available():
+        name = 'cuda'
+        if attempt_gc:
+            torch.cuda.empty_cache() # may have performance implications
+    elif has_dml():
+        name = 'dml'
+
+    return name
+
+def get_device(verbose=False):
+    name = get_device_name()
+
+    if verbose:
+        if name == 'cpu':
+            print("No hardware acceleration is available, falling back to CPU...")    
+        else:
+            print(f"Hardware acceleration found: {name}")
+
+    if name == "dml":
+        import torch_directml
+        return torch_directml.device()
+
+    return torch.device(name)
+
+def get_device_vram( name=get_device_name() ):
+    available = 1
+
+    if name == "cuda":
+        _, available = torch.cuda.mem_get_info()
+    elif name == "cpu":
+        available = psutil.virtual_memory()[4]
+
+    return available / (1024 ** 3)
+
+def get_device_batch_size(name=None):
+    name = get_device_name()
+    vram = get_device_vram(name)
+
+    if vram > 14:
+        return 16
+    elif vram > 10:
+        return 8
+    elif vram > 7:
+        return 4
+    """
+    for k, v in DEVICE_BATCH_SIZE_MAP:
+        if vram > k:
+            return v
+    """
+    return 1
+
+def get_device_count(name=get_device_name()):
+    if name == "cuda":
+        return torch.cuda.device_count()
+    if name == "dml":
+        import torch_directml
+        return torch_directml.device_count()
+
+    return 1
+
+
+if has_dml():
+    _cumsum = torch.cumsum
+    _repeat_interleave = torch.repeat_interleave
+    _multinomial = torch.multinomial
+    
+    _Tensor_new = torch.Tensor.new
+    _Tensor_cumsum = torch.Tensor.cumsum
+    _Tensor_repeat_interleave = torch.Tensor.repeat_interleave
+    _Tensor_multinomial = torch.Tensor.multinomial
+
+    torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
+    torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
+    torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
+    
+    torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
+    torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
+    torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
     torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
\ No newline at end of file