2020-10-07 02:38:38 +00:00
|
|
|
import warnings
|
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
import torch
|
|
|
|
|
2020-10-07 02:38:38 +00:00
|
|
|
|
|
|
|
def detach_variable(inputs):
    """Return a tuple of detached copies of the tensors in ``inputs``.

    Each tensor element is replaced by ``tensor.detach()`` (which shares
    storage with the original) with its ``requires_grad`` flag copied over,
    so the copy is a fresh autograd leaf. Non-tensor elements are passed
    through unchanged.

    Args:
        inputs: A tuple whose elements are tensors or arbitrary objects.

    Returns:
        A tuple with the same length and element order as ``inputs``.

    Raises:
        RuntimeError: If ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build a single message string; passing two positional args to
        # RuntimeError would render the message as a tuple.
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            + type(inputs).__name__)
    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            # Non-tensors (ints, flags, None, ...) have nothing to detach.
            out.append(inp)
            continue
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
|
|
|
|
|
|
|
|
|
|
|
|
def check_backward_validity(inputs):
    """Warn when no tensor in ``inputs`` has ``requires_grad=True``.

    Non-tensor elements are ignored, so mixed argument lists (tensors plus
    ints/flags/None) are accepted instead of raising AttributeError.

    Args:
        inputs: An iterable of tensors and/or arbitrary objects.
    """
    if not any(inp.requires_grad for inp in inputs
               if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None")
|
2020-10-07 02:38:38 +00:00
|
|
|
|
|
|
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that trades compute for memory (activation checkpointing).

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` so no
    intermediate activations are recorded; ``backward`` re-runs it with grad
    enabled and differentiates the recomputed outputs.

    Call as ``CheckpointFunction.apply(run_function, length, *args)``:
    ``run_function(*args[:length])`` is the computation, and ``args[length:]``
    are extra tensors used inside ``run_function`` (e.g. its parameters) that
    are NOT passed to it but for which gradients are returned.

    NOTE(review): the RNG state is not saved/restored, so random ops inside
    ``run_function`` (e.g. dropout) may draw different values when recomputed
    in ``backward``.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Run ``run_function(*args[:length])`` without recording a graph.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function
        # Kept directly on ctx rather than via ctx.save_for_backward.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """Recompute the forward pass with grad enabled and backpropagate.

        Returns:
            ``(None, None)`` for ``run_function`` and ``length``, followed by
            one gradient per element of ``args``. The gradient slot is
            ``None`` for non-tensors, for tensors that do not require grad,
            and for tensors unused by the recomputed outputs.
        """
        # Fresh leaves (sharing storage with the saved inputs) so the
        # recomputation builds a new graph rooted at them.
        detached_inputs = []
        for inp in ctx.input_tensors:
            if isinstance(inp, torch.Tensor):
                x = inp.detach()
                x.requires_grad = inp.requires_grad
                detached_inputs.append(x)
            else:
                detached_inputs.append(inp)

        with torch.enable_grad():
            output_tensors = ctx.run_function(*detached_inputs)
        if isinstance(output_tensors, torch.Tensor):
            output_tensors = (output_tensors,)

        # torch.autograd.grad rejects outputs that are not part of a graph,
        # so only pair grad-requiring outputs with their incoming gradients.
        outputs_with_grad = []
        grads_for_outputs = []
        for out, grad in zip(output_tensors, output_grads):
            if isinstance(out, torch.Tensor) and out.requires_grad:
                outputs_with_grad.append(out)
                grads_for_outputs.append(grad)

        # torch.autograd.grad also raises if any requested input does not
        # require grad, so only differentiate w.r.t. tensors that do.
        candidates = detached_inputs + list(ctx.input_params)
        diff_idx = [i for i, t in enumerate(candidates)
                    if isinstance(t, torch.Tensor) and t.requires_grad]

        input_grads = [None] * len(candidates)
        if outputs_with_grad and diff_idx:
            grads = torch.autograd.grad(
                outputs_with_grad,
                [candidates[i] for i in diff_idx],
                grads_for_outputs,
                allow_unused=True)
            for i, g in zip(diff_idx, grads):
                input_grads[i] = g
        return (None, None) + tuple(input_grads)
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint(module, *params):
    """Call ``module(*params)``, checkpointing activations when it has trainable weights.

    When ``module`` has at least one parameter with ``requires_grad=True``,
    the call is routed through ``CheckpointFunction``. Those trainable
    parameters are appended after the inputs so that gradients for them are
    produced during the backward recomputation. Otherwise the module is
    called directly.

    Args:
        module: Callable object exposing ``parameters()`` (e.g. a
            ``torch.nn.Module``).
        *params: Positional inputs forwarded to ``module``.

    Returns:
        Whatever ``module(*params)`` returns.
    """
    trainable = [p for p in module.parameters() if p.requires_grad]
    if not trainable:
        # No trainable weights: plain call, autograd tracks inputs as usual.
        return module(*params)
    return CheckpointFunction.apply(module, len(params), *params, *trainable)
|