Merge pull request #651 from EyeDeck/master
Add some error handling for VRAM monitor
This commit is contained in:
commit
9e892d90ce
|
@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
|
||||||
self.run_flag = threading.Event()
|
self.run_flag = threading.Event()
|
||||||
self.data = defaultdict(int)
|
self.data = defaultdict(int)
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.cuda.mem_get_info()
|
||||||
|
torch.cuda.memory_stats(self.device)
|
||||||
|
except Exception as e: # AMD or whatever
|
||||||
|
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
||||||
|
self.disabled = True
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
|
@ -62,6 +69,7 @@ class MemUsageMonitor(threading.Thread):
|
||||||
self.run_flag.set()
|
self.run_flag.set()
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
|
if not self.disabled:
|
||||||
free, total = torch.cuda.mem_get_info()
|
free, total = torch.cuda.mem_get_info()
|
||||||
self.data["total"] = total
|
self.data["total"] = total
|
||||||
|
|
||||||
|
|
|
@ -119,6 +119,8 @@ def save_files(js_data, images, index):
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
|
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||||
|
if run_memmon:
|
||||||
shared.mem_mon.monitor()
|
shared.mem_mon.monitor()
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
|
|
||||||
|
@ -136,17 +138,20 @@ def wrap_gradio_call(func):
|
||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
mem_stats = {k: -(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
|
if run_memmon:
|
||||||
|
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||||
active_peak = mem_stats['active_peak']
|
active_peak = mem_stats['active_peak']
|
||||||
reserved_peak = mem_stats['reserved_peak']
|
reserved_peak = mem_stats['reserved_peak']
|
||||||
sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
|
sys_peak = mem_stats['system_peak']
|
||||||
sys_total = mem_stats['total']
|
sys_total = mem_stats['total']
|
||||||
sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
|
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||||
vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.
" \
|
vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.
" \
|
||||||
"Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.
" \
|
"Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.
" \
|
||||||
"Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
|
"Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
|
||||||
|
|
||||||
vram_html = '' if opts.memmon_poll_rate == 0 else f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
vram_html = f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||||
|
else:
|
||||||
|
vram_html = ''
|
||||||
|
|
||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user