diff --git a/modules/memmon.py b/modules/memmon.py
index f2cac841..9fb9b687 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
@@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
 
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
diff --git a/modules/ui.py b/modules/ui.py
index 77f7c2ed..ada84d33 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -119,7 +119,9 @@ def save_files(js_data, images, index):
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
-        shared.mem_mon.monitor()
+        run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
+        if run_memmon:
+            shared.mem_mon.monitor()
         t = time.perf_counter()
 
         try:
@@ -136,17 +138,20 @@ def wrap_gradio_call(func):
 
         elapsed = time.perf_counter() - t
 
-        mem_stats = {k: -(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
-        active_peak = mem_stats['active_peak']
-        reserved_peak = mem_stats['reserved_peak']
-        sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
-        sys_total = mem_stats['total']
-        sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
-        vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data. " \
-                       "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data. " \
-                       "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
+        if run_memmon:
+            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+            active_peak = mem_stats['active_peak']
+            reserved_peak = mem_stats['reserved_peak']
+            sys_peak = mem_stats['system_peak']
+            sys_total = mem_stats['total']
+            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+            vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data. " \
+                           "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data. " \
+                           "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
 
-        vram_html = '' if opts.memmon_poll_rate == 0 else f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+            vram_html = f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        else:
+            vram_html = ''
 
         # last item is always HTML
         res[-1] += f"<div class='performance'><p class='time'>Time taken: {elapsed:.2f}s</p>{vram_html}</div>"