From c67c40f983997594f76b2312f92c3761e8d83715 Mon Sep 17 00:00:00 2001
From: Matthew McGoogan
Date: Sat, 26 Nov 2022 23:25:16 +0000
Subject: [PATCH] torch.cuda.empty_cache() defaults to cuda:0 device unless
 explicitly set otherwise first. Updating torch_gc() to use the device set by
 --device-id if specified to avoid OOM edge cases on multi-GPU systems.

---
 modules/devices.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 67165bf6..93d82bbc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,8 +44,18 @@ def get_optimal_device():
 
 def torch_gc():
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        from modules import shared
+
+        device_id = shared.cmd_opts.device_id
+
+        if device_id is not None:
+            cuda_device = f"cuda:{device_id}"
+        else:
+            cuda_device = "cuda"
+
+        with torch.cuda.device(cuda_device):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 
 def enable_tf32():
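
Note: a minimal standalone sketch of the pattern the patch introduces, independent
of the webui codebase (the helper name scoped_cuda_gc and the device index 1 are
illustrative only). torch.cuda.empty_cache() and torch.cuda.ipc_collect() act on
the currently selected CUDA device, so wrapping them in torch.cuda.device(...)
points them at the GPU chosen via --device-id instead of the default cuda:0.

    import torch

    def scoped_cuda_gc(device_id=None):
        # No-op on CPU-only machines.
        if not torch.cuda.is_available():
            return
        # Mirror the patch: build an explicit "cuda:N" device string when an
        # index is given, otherwise fall back to the current default device.
        cuda_device = f"cuda:{device_id}" if device_id is not None else "cuda"
        # The context manager temporarily selects cuda_device, so the cache
        # release below applies to that GPU rather than to cuda:0.
        with torch.cuda.device(cuda_device):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    scoped_cuda_gc(device_id=1)  # e.g. free cached memory on a second GPU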