import torch
import warnings


def detach_variable(inputs):
    # Detach each tensor in the tuple from the current graph while preserving
    # its requires_grad flag.
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: "
            + type(inputs).__name__
        )


def check_backward_validity(inputs):
    # Warn when checkpointing cannot produce any gradients.
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` entries of `args` are the activations fed to
        # `run_function`; the remaining entries are its differentiable parameters.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run the forward pass without building a graph, so intermediate
        # activations are not stored.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs so the recomputation builds a fresh graph,
        # keeping the original requires_grad flags.
        for i in range(len(ctx.input_tensors)):
            temp = ctx.input_tensors[i]
            ctx.input_tensors[i] = temp.detach()
            ctx.input_tensors[i].requires_grad = temp.requires_grad
        # Recompute the forward pass with grad enabled, then backpropagate
        # through the recomputed graph.
        with torch.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # The first two forward arguments (run_function, length) get no gradient.
        return (None, None) + input_grads


def checkpoint(module, *params):
    # Checkpoint `module` only if it has trainable parameters; otherwise call it
    # directly with autograd's normal behavior, since recomputation saves nothing.
    differentiable_params = tuple(filter(lambda p: p.requires_grad, module.parameters()))
    if len(differentiable_params) > 0:
        args = params + differentiable_params
        return CheckpointFunction.apply(module, len(params), *args)
    else:
        return module(*params)
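

# A minimal usage sketch of checkpoint() under assumed conditions: the toy
# nn.Sequential, tensor shapes, and the __main__ guard below are illustrative
# additions, not part of this module's API. It shows that gradients still reach
# the module's parameters even though the forward pass runs under torch.no_grad().
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    # The input should require grad: torch.autograd.grad in backward() is called
    # on ctx.input_tensors, which fails for tensors that do not require grad.
    x = torch.randn(4, 16, requires_grad=True)

    # Forward runs under no_grad inside CheckpointFunction, so the activations
    # of `net` are not stored; they are recomputed during the backward pass.
    out = checkpoint(net, x)
    out.sum().backward()

    # Gradients reach both the input and the parameters via the recomputation.
    print(x.grad.shape, all(p.grad is not None for p in net.parameters()))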