diff --git a/modules/devices.py b/modules/devices.py
index 919048d0..52c3e7cd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,22 +1,17 @@
-import sys, os, shlex
+import sys
 import contextlib
 import torch
 from modules import errors
-from modules.sd_hijack_utils import CondFunc
-from packaging import version
+
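+# importing mac_specific on macOS also applies the MPS workarounds defined in that module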
+if sys.platform == "darwin":
+    from modules import mac_specific
 
 
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
+    if sys.platform != "darwin":
         return False
-    try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
-    except Exception:
-        return False
-
+    else:
+        return mac_specific.has_mps
+
 
 def extract_device_id(args, name):
     for x in range(len(args)):
@@ -155,36 +150,3 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."
 
     raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
-    if input.device.type == 'mps':
-        output_dtype = kwargs.get('dtype', input.dtype)
-        if output_dtype == torch.int64:
-            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
-            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
-    return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
-    if version.parse(torch.__version__) < version.parse("1.13"):
-        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
-                                                          lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
-                                                                                        lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
-    elif version.parse(torch.__version__) > version.parse("1.13.1"):
-        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
-        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
-        CondFunc('torch.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
new file mode 100644
index 00000000..e39d670e
--- /dev/null
+++ b/modules/mac_specific.py
@@ -0,0 +1,56 @@
+import torch
+from modules import paths
+from modules.sd_hijack_utils import CondFunc
+from packaging import version
+
+
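+# the torch device in use; assigned by modules.shared once the devices have been selected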
+device = None
+
+
+# MPS is only available in nightly pytorch (for now) and on macOS 12.3+,
+# so check `getattr(torch, 'has_mps', ...)` for compatibility and verify that a tensor can actually be created on the device
+def check_for_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
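+
+
+# run the check once at import time and cache the result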
+has_mps = check_for_mps()
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if output_dtype == torch.int64:
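+            # int64 inputs are routed through the CPU to avoid the MPS cumsum bug referenced above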
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+        elif (cumsum_needs_bool_fix and output_dtype == torch.bool) or (cumsum_needs_int_fix and output_dtype in (torch.int8, torch.int16)):
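+            # smaller int and bool inputs are promoted to int32 for the cumsum, then the result is cast back to int64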
+            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps:
+    # MPS fix for randn in torchsde
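+    # (the noise is generated on the CPU with a seeded generator and then moved to the requested MPS device)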
+    CondFunc('torchsde._brownian.brownian_interval._randn',
+             lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device),
+             lambda _, size, dtype, device, seed: device.type == 'mps')
+
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes, but it is unfortunately slower and has regressions that prevent training from working
+
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+        CondFunc('torch.Tensor.to',
+                 lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+                 lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+        CondFunc('torch.nn.functional.layer_norm',
+                 lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+                 lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+        CondFunc('torch.Tensor.numpy',
+                 lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs),
+                 lambda _, self, *args, **kwargs: self.requires_grad)
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
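+        # probe whether this torch build still returns wrong cumsum results for small int and bool dtypes on MPS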
+        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+        CondFunc('torch.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 3c03d442..a1aac7cf 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,6 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-import torchsde._brownian.brownian_interval
 from modules import devices, processing, images, sd_vae_approx
 
 from modules.shared import opts, state
@@ -61,18 +60,3 @@ def store_latent(decoded):
 
 class InterruptedException(BaseException):
     pass
-
-
-# MPS fix for randn in torchsde
-# XXX move this to separate file for MPS
-def torchsde_randn(size, dtype, device, seed):
-    if device.type == 'mps':
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
-    else:
-        generator = torch.Generator(device).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
-
diff --git a/modules/shared.py b/modules/shared.py
index 5600d480..59f12cd8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -145,6 +145,9 @@ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.devic
     (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
 
 device = devices.device
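+# on macOS, share the selected device with the mac_specific module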
+if sys.platform == "darwin":
+    from modules import mac_specific
+    mac_specific.device = device
 weight_load_location = None if cmd_opts.lowram else "cpu"
 
 batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)