diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index 7b16e7f..8e065b8 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -1,127 +1,128 @@ -import torch -import psutil -import importlib - -DEVICE_OVERRIDE = None -DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] - -from inspect import currentframe, getframeinfo -import gc - -def do_gc(): - gc.collect() - try: - torch.cuda.empty_cache() - except Exception as e: - pass - -def print_stats(collect=False): - cf = currentframe().f_back - msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}' - - if collect: - do_gc() - - tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) - res = torch.cuda.memory_reserved(0) / (1024 ** 3) - alloc = torch.cuda.memory_allocated(0) / (1024 ** 3) - print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res )) - - -def has_dml(): - loader = importlib.find_loader('torch_directml') - if loader is None: - return False - - import torch_directml - return torch_directml.is_available() - -def set_device_name(name): - global DEVICE_OVERRIDE - DEVICE_OVERRIDE = name - -def get_device_name(attempt_gc=True): - global DEVICE_OVERRIDE - if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "": - return DEVICE_OVERRIDE - - name = 'cpu' - - if torch.cuda.is_available(): - name = 'cuda' - if attempt_gc: - torch.cuda.empty_cache() # may have performance implications - elif has_dml(): - name = 'dml' - - return name - -def get_device(verbose=False): - name = get_device_name() - - if verbose: - if name == 'cpu': - print("No hardware acceleration is available, falling back to CPU...") - else: - print(f"Hardware acceleration found: {name}") - - if name == "dml": - import torch_directml - return torch_directml.device() - - return torch.device(name) - -def get_device_vram( name=get_device_name() ): - available = 1 - - if name == "cuda": - _, available = torch.cuda.mem_get_info() - elif name == "cpu": - available = psutil.virtual_memory()[4] - - return available / (1024 ** 3) - -def get_device_batch_size(name=None): - vram = get_device_vram(name) - - if vram > 14: - return 16 - elif vram > 10: - return 8 - elif vram > 7: - return 4 - """ - for k, v in DEVICE_BATCH_SIZE_MAP: - if vram > k: - return v - """ - return 1 - -def get_device_count(name=get_device_name()): - if name == "cuda": - return torch.cuda.device_count() - if name == "dml": - import torch_directml - return torch_directml.device_count() - - return 1 - - -if has_dml(): - _cumsum = torch.cumsum - _repeat_interleave = torch.repeat_interleave - _multinomial = torch.multinomial - - _Tensor_new = torch.Tensor.new - _Tensor_cumsum = torch.Tensor.cumsum - _Tensor_repeat_interleave = torch.Tensor.repeat_interleave - _Tensor_multinomial = torch.Tensor.multinomial - - torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) ) - torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) ) - torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) ) - - torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) ) - torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) ) - torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) ) +import torch +import psutil +import importlib + +DEVICE_OVERRIDE = None +DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] + +from inspect import currentframe, getframeinfo +import gc + +def do_gc(): + gc.collect() + try: + torch.cuda.empty_cache() + except Exception as e: + pass + +def print_stats(collect=False): + cf = currentframe().f_back + msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}' + + if collect: + do_gc() + + tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) + res = torch.cuda.memory_reserved(0) / (1024 ** 3) + alloc = torch.cuda.memory_allocated(0) / (1024 ** 3) + print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res )) + + +def has_dml(): + loader = importlib.find_loader('torch_directml') + if loader is None: + return False + + import torch_directml + return torch_directml.is_available() + +def set_device_name(name): + global DEVICE_OVERRIDE + DEVICE_OVERRIDE = name + +def get_device_name(attempt_gc=True): + global DEVICE_OVERRIDE + if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "": + return DEVICE_OVERRIDE + + name = 'cpu' + + if torch.cuda.is_available(): + name = 'cuda' + if attempt_gc: + torch.cuda.empty_cache() # may have performance implications + elif has_dml(): + name = 'dml' + + return name + +def get_device(verbose=False): + name = get_device_name() + + if verbose: + if name == 'cpu': + print("No hardware acceleration is available, falling back to CPU...") + else: + print(f"Hardware acceleration found: {name}") + + if name == "dml": + import torch_directml + return torch_directml.device() + + return torch.device(name) + +def get_device_vram( name=get_device_name() ): + available = 1 + + if name == "cuda": + _, available = torch.cuda.mem_get_info() + elif name == "cpu": + available = psutil.virtual_memory()[4] + + return available / (1024 ** 3) + +def get_device_batch_size(name=None): + name = get_device_name() + vram = get_device_vram(name) + + if vram > 14: + return 16 + elif vram > 10: + return 8 + elif vram > 7: + return 4 + """ + for k, v in DEVICE_BATCH_SIZE_MAP: + if vram > k: + return v + """ + return 1 + +def get_device_count(name=get_device_name()): + if name == "cuda": + return torch.cuda.device_count() + if name == "dml": + import torch_directml + return torch_directml.device_count() + + return 1 + + +if has_dml(): + _cumsum = torch.cumsum + _repeat_interleave = torch.repeat_interleave + _multinomial = torch.multinomial + + _Tensor_new = torch.Tensor.new + _Tensor_cumsum = torch.Tensor.cumsum + _Tensor_repeat_interleave = torch.Tensor.repeat_interleave + _Tensor_multinomial = torch.Tensor.multinomial + + torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) ) + torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) ) + torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) ) + + torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) ) + torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) ) \ No newline at end of file