Make dalle transformer checkpointable
parent 70b17da193
commit 7de3874f15
@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn as nn
 from operator import itemgetter
@@ -5,6 +7,9 @@ from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 # for routing arguments into the functions of the reversible layer
+from utils.util import checkpoint
+
+
 def route_args(router, args, depth):
     routed_args = [(dict(), dict()) for _ in range(depth)]
     matched_keys = [key for key in args.keys() if key in router]
@@ -136,8 +141,8 @@ class SequentialSequence(nn.Module):
         layers_and_args = list(zip(self.layers, args))
 
         for (f, g), (f_args, g_args) in layers_and_args:
-            x = x + f(x, **f_args)
-            x = x + g(x, **g_args)
+            x = x + checkpoint(functools.partial(f, **f_args), x)
+            x = x + checkpoint(functools.partial(g, **g_args), x)
         return x
 
 class ReversibleSequence(nn.Module):
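
Note: the effect of this change is that each attention/feed-forward residual branch in SequentialSequence.forward now runs under activation checkpointing, so intermediate activations are recomputed during backward instead of stored. The repository's utils.util.checkpoint wrapper is not part of this diff; below is a minimal sketch of what such a wrapper typically does, assuming it defers to torch.utils.checkpoint and skips checkpointing when no gradient is needed (the fallback behavior is an assumption, not the repository's actual implementation).

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def checkpoint(fn, *args):
    # Hypothetical stand-in for utils.util.checkpoint.
    # If any input still requires grad, trade memory for compute:
    # run fn without saving intermediates, recompute them on backward.
    if any(isinstance(a, torch.Tensor) and a.requires_grad for a in args):
        return torch_checkpoint(fn, *args)
    # Inference / no-grad path: checkpointing would only add overhead.
    return fn(*args)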
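
Why functools.partial rather than passing f_args through: torch.utils.checkpoint.checkpoint forwards positional arguments to the wrapped function (keyword-argument support varies across PyTorch versions), so the per-layer kwargs produced by route_args are bound into the callable first, and only the activation tensor x crosses the checkpoint boundary. A small usage sketch of the rewritten loop body; the layer, shapes, and empty f_args are illustrative only:

import functools
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

f = nn.Linear(64, 64)    # stands in for one attention or FF block
x = torch.randn(2, 16, 64, requires_grad=True)
f_args = {}              # per-layer kwargs as routed by route_args

# Bind kwargs first; checkpoint then sees a callable taking only x.
x = x + checkpoint(functools.partial(f, **f_args), x)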