forked from mrq/DL-Art-School
Push training_state data to CPU memory before saving it
For whatever reason, keeping the training state in GPU memory just doesn't work: when the checkpoint is loaded, it consumes a large amount of GPU memory and that utilization doesn't go away. Saving to CPU first should fix this.
parent 6000580e2e
commit 93a3302819
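The core idea, as a minimal standalone sketch (not code from this repo): torch.save records each tensor's device, so a training state written from CUDA tensors is restored onto the GPU when loaded; mapping everything to CPU before saving keeps the checkpoint off the GPU entirely. File names and tensors below are illustrative.

import torch

# Illustrative only: a tiny "training state" with a tensor on whatever device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state = {'iter': 100, 'weights': torch.randn(4, 4, device=device)}

# Saved as-is, torch.load('gpu_state.pth') would put 'weights' back on the device
# it was saved from, allocating GPU memory just to read the checkpoint.
torch.save(state, 'gpu_state.pth')

# Moving every tensor to CPU first keeps the checkpoint device-agnostic.
cpu_state = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
torch.save(cpu_state, 'cpu_state.pth')

# map_location='cpu' is the complementary safeguard on the loading side.
restored = torch.load('cpu_state.pth', map_location='cpu')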
@@ -6,7 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
-from utils.util import opt_get, optimizer_to
+from utils.util import opt_get, optimizer_to, map_to_device
 
 
 class BaseModel():
@@ -148,7 +148,7 @@ class BaseModel():
             state['amp'] = amp.state_dict()
         save_filename = '{}.state'.format(utils.util.opt_get(state, ['iter'], 'no_step_provided'))
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
-        torch.save(state, save_path)
+        torch.save(map_to_device(state, 'cpu'), save_path)
         if '__state__' not in self.save_history.keys():
             self.save_history['__state__'] = []
         self.save_history['__state__'].append(save_path)
@@ -499,6 +499,22 @@ def map_cuda_to_correct_device(storage, loc):
     else:
         return storage.cpu()
 
 
+def list_to_device(l, dev):
+    return [anything_to_device(e, dev) for e in l]
+
+def map_to_device(m, dev):
+    return {k: anything_to_device(v, dev) for k,v in m.items()}
+
+def anything_to_device(obj, dev):
+    if isinstance(obj, list):
+        return list_to_device(obj, dev)
+    elif isinstance(obj, map):
+        return map_to_device(obj, dev)
+    elif isinstance(obj, torch.Tensor):
+        return obj.to(dev)
+    else:
+        return obj
+
 def ceil_multiple(base, multiple):
     """
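A self-contained sketch of the recursive device-mapping helper this hunk introduces, runnable on its own. One assumption to flag: the sketch checks isinstance(obj, dict) where the diff checks isinstance(obj, map); in Python, map names the built-in iterator type, so nested dicts (for example the optimizer state_dicts inside state['optimizers']) would fall through to the final else branch under the committed check.

import torch

def anything_to_device(obj, dev):
    # Recurse through lists and dicts, moving every tensor found to `dev`;
    # anything else (ints, strings, None) is returned unchanged.
    if isinstance(obj, list):
        return [anything_to_device(e, dev) for e in obj]
    elif isinstance(obj, dict):  # assumption: dict, not the diff's `map`
        return {k: anything_to_device(v, dev) for k, v in obj.items()}
    elif isinstance(obj, torch.Tensor):
        return obj.to(dev)
    else:
        return obj

# Example: a nested training state shaped like the one saved in the hunk above.
state = {'iter': 1, 'epoch': 0, 'optimizers': [{'exp_avg': torch.zeros(3)}]}
cpu_state = anything_to_device(state, 'cpu')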
@@ -524,4 +540,4 @@ def optimizer_to(opt, device):
                 if isinstance(subparam, torch.Tensor):
                     subparam.data = subparam.data.to(device)
                     if subparam._grad is not None:
-                        subparam._grad.data = subparam._grad.data.to(device)
+                        subparam._grad.data = subparam._grad.data.to(device)
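On the loading side, a hypothetical round trip (file name, model, and optimizer below are illustrative, not this repo's resume code): loading with map_location='cpu' keeps the checkpoint off the GPU while it is parsed, and the optimizer_to helper from the last hunk is what moves the restored optimizer tensors back to the training device afterwards.

import torch
import torch.nn as nn

# Illustrative setup so the optimizer has state worth checkpointing.
model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

state = {'iter': 1, 'optimizers': [optimizer.state_dict()]}
torch.save(state, 'train.state')

# map_location='cpu' keeps the load off the GPU even if the checkpoint was
# written from CUDA tensors; paired with saving a CPU-mapped state, no GPU
# memory is touched until training explicitly moves things back.
resumed = torch.load('train.state', map_location='cpu')
optimizer.load_state_dict(resumed['optimizers'][0])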