DirectML kludge
This commit is contained in:
parent
ea9bd9fc74
commit
db4cac5d1f
|
@ -1,17 +1,25 @@
|
|||
import sys
|
||||
import sys, os, shlex
|
||||
import contextlib
|
||||
import torch
|
||||
from modules import errors
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
from packaging import version
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def has_mps() -> bool:
|
||||
if sys.platform != "darwin":
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
else:
|
||||
return mac_specific.has_mps
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def has_dml():
|
||||
import importlib
|
||||
loader = importlib.find_loader('torch_directml')
|
||||
return loader is not None
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
|
@ -31,16 +39,23 @@ def get_cuda_device_string():
|
|||
|
||||
|
||||
def get_optimal_device_name():
|
||||
if torch.cuda.is_available():
|
||||
return get_cuda_device_string()
|
||||
if has_dml():
|
||||
return "dml"
|
||||
|
||||
if has_mps():
|
||||
return "mps"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return get_cuda_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if get_optimal_device_name() == "dml":
|
||||
import torch_directml
|
||||
return torch_directml.device()
|
||||
|
||||
return torch.device(get_optimal_device_name())
|
||||
|
||||
|
||||
|
@ -150,3 +165,75 @@ def test_for_nans(x, where):
|
|||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
if self.device.type != 'mps' and \
|
||||
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
|
||||
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
|
||||
self = self.contiguous()
|
||||
return orig_tensor_to(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
orig_layer_norm = torch.nn.functional.layer_norm
|
||||
def layer_norm_fix(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
|
||||
args = list(args)
|
||||
args[0] = args[0].contiguous()
|
||||
return orig_layer_norm(*args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
orig_tensor_numpy = torch.Tensor.numpy
|
||||
def numpy_fix(self, *args, **kwargs):
|
||||
if self.requires_grad:
|
||||
self = self.detach()
|
||||
return orig_tensor_numpy(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
orig_cumsum = torch.cumsum
|
||||
orig_Tensor_cumsum = torch.Tensor.cumsum
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
if has_mps():
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
torch.Tensor.to = tensor_to_fix
|
||||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
|
||||
if has_dml():
|
||||
_cumsum = torch.cumsum
|
||||
_repeat_interleave = torch.repeat_interleave
|
||||
_multinomial = torch.multinomial
|
||||
|
||||
_Tensor_new = torch.Tensor.new
|
||||
_Tensor_cumsum = torch.Tensor.cumsum
|
||||
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave
|
||||
_Tensor_multinomial = torch.Tensor.multinomial
|
||||
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
|
||||
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
@ -207,7 +207,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device()
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
|
|
@ -3,6 +3,6 @@
|
|||
set PYTHON=
|
||||
set GIT=
|
||||
set VENV_DIR=
|
||||
set COMMANDLINE_ARGS=
|
||||
set COMMANDLINE_ARGS=--skip-torch-cuda-test --precision full --no-half
|
||||
|
||||
call webui.bat
|
||||
|
|
Loading…
Reference in New Issue
Block a user