From 7de3874f15c09421716f9f709aecb049c4f6dedf Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 9 Jan 2022 19:14:35 -0700 Subject: [PATCH] Make dalle transformer checkpointable --- codes/models/lucidrains/dalle/reversible.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/codes/models/lucidrains/dalle/reversible.py b/codes/models/lucidrains/dalle/reversible.py index a235323a..c55416ae 100644 --- a/codes/models/lucidrains/dalle/reversible.py +++ b/codes/models/lucidrains/dalle/reversible.py @@ -1,3 +1,5 @@ +import functools + import torch import torch.nn as nn from operator import itemgetter @@ -5,6 +7,9 @@ from torch.autograd.function import Function from torch.utils.checkpoint import get_device_states, set_device_states # for routing arguments into the functions of the reversible layer +from utils.util import checkpoint + + def route_args(router, args, depth): routed_args = [(dict(), dict()) for _ in range(depth)] matched_keys = [key for key in args.keys() if key in router] @@ -136,8 +141,8 @@ class SequentialSequence(nn.Module): layers_and_args = list(zip(self.layers, args)) for (f, g), (f_args, g_args) in layers_and_args: - x = x + f(x, **f_args) - x = x + g(x, **g_args) + x = x + checkpoint(functools.partial(f, **f_args), x) + x = x + checkpoint(functools.partial(g, **g_args), x) return x class ReversibleSequence(nn.Module):