Add GPU mem tracing module
commit 6c0e9f45c7
parent 48532a0a8a
.gitignore (vendored) | 1 +

@@ -3,6 +3,7 @@ results/*
 tb_logger/*
 datasets/*
 options/*
+codes/*.txt
 .vscode
 
 *.html
@@ -210,7 +210,13 @@ class SRGANModel(BaseModel):
         self.fake_GenOut = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
 
+            #from utils import gpu_mem_track
+            #import inspect
+            #gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
+            #gpu_tracker.track()
             fake_GenOut = self.netG(var_L)
+            #gpu_tracker.track()
+
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
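The tracing hooks above are committed commented-out, so the change is inert by default. Enabling them is just a matter of uncommenting; a minimal sketch of the intended call pattern (names taken directly from the commented lines, with the two track() calls bracketing the generator forward pass so the logged delta isolates tensors allocated by netG):

    from utils import gpu_mem_track
    import inspect

    gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
    gpu_tracker.track()                 # snapshot before the forward pass
    fake_GenOut = self.netG(var_L)
    gpu_tracker.track()                 # snapshot after; the delta shows tensors allocated by netG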
codes/utils/gpu_mem_track.py (new file) | 79 +

@@ -0,0 +1,79 @@
+import gc
+import datetime
+import pynvml
+
+import torch
+import numpy as np
+
+
+class MemTracker(object):
+    """
+    Class used to track pytorch memory usage
+    Arguments:
+        frame: a frame to detect current py-file runtime
+        detail(bool, default True): whether to show detailed gpu memory usage
+        path(str): where to save the log file
+        verbose(bool, default False): whether to show trivial exceptions
+        device(int): GPU number, default is 0
+    """
+    def __init__(self, frame, detail=True, path='', verbose=False, device=0):
+        self.frame = frame
+        self.print_detail = detail
+        self.last_tensor_sizes = set()
+        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H.%M.%S}-gpu_mem_track.txt'
+        self.verbose = verbose
+        self.begin = True
+        self.device = device
+
+        self.func_name = frame.f_code.co_name
+        self.filename = frame.f_globals["__file__"]
+        if (self.filename.endswith(".pyc") or
+                self.filename.endswith(".pyo")):
+            self.filename = self.filename[:-1]
+        self.module_name = self.frame.f_globals["__name__"]
+        self.curr_line = self.frame.f_lineno
+
+    def get_tensors(self):
+        for obj in gc.get_objects():
+            try:
+                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+                    tensor = obj
+                else:
+                    continue
+                if tensor.is_cuda:
+                    yield tensor
+            except Exception as e:
+                if self.verbose:
+                    print('A trivial exception occurred: {}'.format(e))
+
+    def track(self):
+        """
+        Track the GPU memory usage
+        """
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
+        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        self.curr_line = self.frame.f_lineno
+        where_str = self.module_name + ' ' + self.func_name + ':' + ' line ' + str(self.curr_line)
+
+        with open(self.gpu_profile_fn, 'a+') as f:
+
+            if self.begin:
+                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
+                        f" Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n")
+                self.begin = False
+
+            if self.print_detail is True:
+                ts_list = [tensor.size() for tensor in self.get_tensors()]
+                new_tensor_sizes = {(type(x), tuple(x.size()), ts_list.count(x.size()), np.prod(np.array(x.size()))*4/1000**2)
+                                    for x in self.get_tensors()}
+                for t, s, n, m in new_tensor_sizes - self.last_tensor_sizes:
+                    f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20}\n')
+                for t, s, n, m in self.last_tensor_sizes - new_tensor_sizes:
+                    f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} \n')
+                self.last_tensor_sizes = new_tensor_sizes
+
+            f.write(f"\nAt {where_str:<50}"
+                    f"Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n")
+
+        pynvml.nvmlShutdown()