diff --git a/.gitignore b/.gitignore
index ef150caa..3243cc98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 results/*
 tb_logger/*
 datasets/*
 options/*
+codes/*.txt
 .vscode
 *.html
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index f26f2aab..e3ce73eb 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -210,7 +210,13 @@ class SRGANModel(BaseModel):
         self.fake_GenOut = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
+
+            #from utils import gpu_mem_track
+            #import inspect
+            #gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
+            #gpu_tracker.track()
             fake_GenOut = self.netG(var_L)
+            #gpu_tracker.track()
 
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
diff --git a/codes/utils/gpu_mem_track.py b/codes/utils/gpu_mem_track.py
new file mode 100644
index 00000000..c2b43946
--- /dev/null
+++ b/codes/utils/gpu_mem_track.py
@@ -0,0 +1,81 @@
+import gc
+import datetime
+import pynvml
+
+import torch
+import numpy as np
+
+
+class MemTracker(object):
+    """
+    Class used to track PyTorch GPU memory usage.
+    Arguments:
+        frame: the calling frame, used to report the current module/function/line
+        detail(bool, default True): whether to log per-tensor GPU memory usage
+        path(str): where to save the log file
+        verbose(bool, default False): whether to print trivial exceptions
+        device(int): GPU index, default is 0
+    """
+    def __init__(self, frame, detail=True, path='', verbose=False, device=0):
+        self.frame = frame
+        self.print_detail = detail
+        self.last_tensor_sizes = set()
+        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H.%M.%S}-gpu_mem_track.txt'
+        self.verbose = verbose
+        self.begin = True
+        self.device = device
+
+        self.func_name = frame.f_code.co_name
+        self.filename = frame.f_globals["__file__"]
+        if (self.filename.endswith(".pyc") or
+                self.filename.endswith(".pyo")):
+            self.filename = self.filename[:-1]
+        self.module_name = self.frame.f_globals["__name__"]
+        self.curr_line = self.frame.f_lineno
+
+    def get_tensors(self):
+        # Walk every object the garbage collector knows about and yield the
+        # tensors (or objects wrapping a tensor in .data) that live on the GPU.
+        for obj in gc.get_objects():
+            try:
+                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+                    tensor = obj
+                else:
+                    continue
+                if tensor.is_cuda:
+                    yield tensor
+            except Exception as e:
+                if self.verbose:
+                    print('A trivial exception occurred: {}'.format(e))
+
+    def track(self):
+        """
+        Track the GPU memory usage
+        """
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
+        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        self.curr_line = self.frame.f_lineno
+        where_str = self.module_name + ' ' + self.func_name + ':' + ' line ' + str(self.curr_line)
+
+        with open(self.gpu_profile_fn, 'a+') as f:
+
+            if self.begin:
+                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
+                        f" Total Used Memory:{meminfo.used/1000**2:<7.1f}MB\n\n")
+                self.begin = False
+
+            if self.print_detail is True:
+                ts_list = [tensor.size() for tensor in self.get_tensors()]
+                new_tensor_sizes = {(type(x), tuple(x.size()), ts_list.count(x.size()), np.prod(np.array(x.size()))*4/1000**2)
+                                    for x in self.get_tensors()}
+                for t, s, n, m in new_tensor_sizes - self.last_tensor_sizes:
+                    f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20}\n')
+                for t, s, n, m in self.last_tensor_sizes - new_tensor_sizes:
+                    f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20}\n')
+                self.last_tensor_sizes = new_tensor_sizes
+
+            f.write(f"\nAt {where_str:<50}"
+                    f"Total Used Memory:{meminfo.used/1000**2:<7.1f}MB\n\n")
+
+        pynvml.nvmlShutdown()
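
Usage note (not part of the patch): the commented-out block in SRGAN_model.py above shows the intended pattern: construct a MemTracker from the current frame, then call track() around the allocation being profiled. Below is a minimal standalone sketch of the same idea; the tensor shape is invented for illustration, and it assumes pynvml is installed and a CUDA device is available. Each track() call appends to a <timestamp>-gpu_mem_track.txt log in the working directory, which is presumably why codes/*.txt was added to .gitignore. The per-tensor "Memory" column assumes 4-byte (float32) elements, per the np.prod(...)*4 term in track().

    import inspect

    import torch

    from utils import gpu_mem_track

    # The tracker captures the calling frame so each log line can report
    # module, function, and line number.
    gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())

    gpu_tracker.track()                      # baseline: total used GPU memory
    x = torch.randn(16, 3, 128, 128).cuda()  # hypothetical allocation to observe
    gpu_tracker.track()                      # logs the new tensor as a '+' line
    del x
    torch.cuda.empty_cache()                 # release cached blocks so NVML sees the drop
    gpu_tracker.track()                      # logs the freed tensor as a '-' line

Note that meminfo.used comes from NVML, so the "Total Used Memory" figure reflects device-wide usage (all processes), while the per-tensor '+'/'-' lines cover only tensors visible to this process's garbage collector.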