From 42909183594db7d74b6c789b930ab9ef33477000 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 6 Oct 2020 20:38:38 -0600
Subject: [PATCH] Add distributed_checkpoint for more efficient checkpoints

---
 codes/process_video.py               |  1 +
 codes/utils/distributed_checkpont.py | 51 ++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+)
 create mode 100644 codes/utils/distributed_checkpont.py

diff --git a/codes/process_video.py b/codes/process_video.py
index e8bea43c..bc91d317 100644
--- a/codes/process_video.py
+++ b/codes/process_video.py
@@ -105,6 +105,7 @@ if __name__ == "__main__":
                        screen=True, tofile=True)
     logger = logging.getLogger('base')
     logger.info(option.dict2str(opt))
+    util.loaded_options = opt
 
     #### Create test dataset and dataloader
     test_loaders = []
diff --git a/codes/utils/distributed_checkpont.py b/codes/utils/distributed_checkpont.py
new file mode 100644
index 00000000..33e8260d
--- /dev/null
+++ b/codes/utils/distributed_checkpont.py
@@ -0,0 +1,51 @@
+import torch
+import warnings
+
+
+def detach_variable(inputs):
+    if isinstance(inputs, tuple):
+        out = []
+        for inp in inputs:
+            x = inp.detach()
+            x.requires_grad = inp.requires_grad
+            out.append(x)
+        return tuple(out)
+    else:
+        raise RuntimeError(
+            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
+
+
+def check_backward_validity(inputs):
+    if not any(inp.requires_grad for inp in inputs):
+        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        for i in range(len(ctx.input_tensors)):
+            temp = ctx.input_tensors[i]
+            ctx.input_tensors[i] = temp.detach()
+            ctx.input_tensors[i].requires_grad = temp.requires_grad
+        with torch.enable_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
+        return (None, None) + input_grads
+
+
+def checkpoint(module, *params):
+    differentiable_params = tuple(filter(lambda p: p.requires_grad, module.parameters()))
+    if len(differentiable_params) > 0:
+        args = params + differentiable_params
+        return CheckpointFunction.apply(module, len(params), *args)
+    else:
+        return module(*params)
\ No newline at end of file
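
Note on usage: the checkpoint() helper added in codes/utils/distributed_checkpont.py runs the wrapped module under torch.no_grad() in the forward pass and re-runs it under torch.enable_grad() during backward, so intermediate activations are recomputed rather than stored. Below is a minimal sketch of how it might be called; the Block module, tensor shapes, and the import path (which assumes the working directory is codes/) are hypothetical and not part of this patch.

    import torch
    import torch.nn as nn
    from utils.distributed_checkpont import checkpoint  # module added by this patch

    class Block(nn.Module):
        """Hypothetical sub-module; any nn.Module with differentiable parameters works."""
        def __init__(self, channels):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x):
            return torch.relu(self.conv(x))

    block = Block(16)
    # Input must require grad, since backward() differentiates w.r.t. the saved inputs.
    x = torch.randn(2, 16, 32, 32, requires_grad=True)

    # Forward runs under no_grad(), so activations inside `block` are not kept.
    y = checkpoint(block, x)

    # During backward, CheckpointFunction re-runs block(x) with grad enabled and
    # returns gradients for both the input and the module parameters.
    y.mean().backward()
    assert x.grad is not None and block.conv.weight.grad is not None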