From 93a330281961335b3cd899e3983e2336cf9b2726 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Mar 2022 17:57:33 -0700 Subject: [PATCH] Push training_state data to CPU memory before saving it For whatever reason, keeping this on GPU memory just doesn't work. When you load it, it consumes a large amount of GPU memory and that utilization doesn't go away. Saving to CPU should fix this. --- codes/trainer/base_model.py | 4 ++-- codes/utils/util.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 9ca2fa9b..f1c38b63 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -6,7 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer from torch.nn.parallel.distributed import DistributedDataParallel import utils.util -from utils.util import opt_get, optimizer_to +from utils.util import opt_get, optimizer_to, map_to_device class BaseModel(): @@ -148,7 +148,7 @@ class BaseModel(): state['amp'] = amp.state_dict() save_filename = '{}.state'.format(utils.util.opt_get(state, ['iter'], 'no_step_provided')) save_path = os.path.join(self.opt['path']['training_state'], save_filename) - torch.save(state, save_path) + torch.save(map_to_device(state, 'cpu'), save_path) if '__state__' not in self.save_history.keys(): self.save_history['__state__'] = [] self.save_history['__state__'].append(save_path) diff --git a/codes/utils/util.py b/codes/utils/util.py index bbc97772..a64ffbb3 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -499,6 +499,22 @@ def map_cuda_to_correct_device(storage, loc): else: return storage.cpu() +def list_to_device(l, dev): + return [anything_to_device(e, dev) for e in l] + +def map_to_device(m, dev): + return {k: anything_to_device(v, dev) for k,v in m.items()} + +def anything_to_device(obj, dev): + if isinstance(obj, list): + return list_to_device(obj, dev) + elif isinstance(obj, map): + return map_to_device(obj, dev) + elif isinstance(obj, torch.Tensor): + return obj.to(dev) + else: + return obj + def ceil_multiple(base, multiple): """ @@ -524,4 +540,4 @@ def optimizer_to(opt, device): if isinstance(subparam, torch.Tensor): subparam.data = subparam.data.to(device) if subparam._grad is not None: - subparam._grad.data = subparam._grad.data.to(device) \ No newline at end of file + subparam._grad.data = subparam._grad.data.to(device)