Make dalle transformer checkpointable

This commit is contained in:
James Betker 2022-01-09 19:14:35 -07:00
parent 70b17da193
commit 7de3874f15

View File

@@ -1,3 +1,5 @@
import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
from operator import itemgetter from operator import itemgetter
@@ -5,6 +7,9 @@ from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states from torch.utils.checkpoint import get_device_states, set_device_states
# for routing arguments into the functions of the reversible layer # for routing arguments into the functions of the reversible layer
from utils.util import checkpoint
def route_args(router, args, depth): def route_args(router, args, depth):
routed_args = [(dict(), dict()) for _ in range(depth)] routed_args = [(dict(), dict()) for _ in range(depth)]
matched_keys = [key for key in args.keys() if key in router] matched_keys = [key for key in args.keys() if key in router]
@@ -136,8 +141,8 @@ class SequentialSequence(nn.Module):
layers_and_args = list(zip(self.layers, args)) layers_and_args = list(zip(self.layers, args))
for (f, g), (f_args, g_args) in layers_and_args: for (f, g), (f_args, g_args) in layers_and_args:
x = x + f(x, **f_args) x = x + checkpoint(functools.partial(f, **f_args), x)
x = x + g(x, **g_args) x = x + checkpoint(functools.partial(g, **g_args), x)
return x return x
class ReversibleSequence(nn.Module): class ReversibleSequence(nn.Module):