Add distributed_checkpoint for more efficient checkpoints
This commit is contained in:
parent
e4b89a172f
commit
4290918359
|
@ -105,6 +105,7 @@ if __name__ == "__main__":
|
|||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
util.loaded_options = opt
|
||||
|
||||
#### Create test dataset and dataloader
|
||||
test_loaders = []
|
||||
|
|
51
codes/utils/distributed_checkpont.py
Normal file
51
codes/utils/distributed_checkpont.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
import warnings
|
||||
|
||||
|
||||
def detach_variable(inputs):
|
||||
if isinstance(inputs, tuple):
|
||||
out = []
|
||||
for inp in inputs:
|
||||
x = inp.detach()
|
||||
x.requires_grad = inp.requires_grad
|
||||
out.append(x)
|
||||
return tuple(out)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
|
||||
|
||||
|
||||
def check_backward_validity(inputs):
|
||||
if not any(inp.requires_grad for inp in inputs):
|
||||
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
for i in range(len(ctx.input_tensors)):
|
||||
temp = ctx.input_tensors[i]
|
||||
ctx.input_tensors[i] = temp.detach()
|
||||
ctx.input_tensors[i].requires_grad = temp.requires_grad
|
||||
with torch.enable_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def checkpoint(module, *params):
|
||||
differentiable_params = tuple(filter(lambda p: p.requires_grad, module.parameters()))
|
||||
if len(differentiable_params) > 0:
|
||||
args = params + differentiable_params
|
||||
return CheckpointFunction.apply(module, len(params), *args)
|
||||
else:
|
||||
return module(*params)
|
Loading…
Reference in New Issue
Block a user