Make dalle transformer checkpointable
This commit is contained in:
parent
70b17da193
commit
7de3874f15
|
@@ -1,3 +1,5 @@
|
||||||
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
@@ -5,6 +7,9 @@ from torch.autograd.function import Function
|
||||||
from torch.utils.checkpoint import get_device_states, set_device_states
|
from torch.utils.checkpoint import get_device_states, set_device_states
|
||||||
|
|
||||||
# for routing arguments into the functions of the reversible layer
|
# for routing arguments into the functions of the reversible layer
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
def route_args(router, args, depth):
|
def route_args(router, args, depth):
|
||||||
routed_args = [(dict(), dict()) for _ in range(depth)]
|
routed_args = [(dict(), dict()) for _ in range(depth)]
|
||||||
matched_keys = [key for key in args.keys() if key in router]
|
matched_keys = [key for key in args.keys() if key in router]
|
||||||
|
@@ -136,8 +141,8 @@ class SequentialSequence(nn.Module):
|
||||||
layers_and_args = list(zip(self.layers, args))
|
layers_and_args = list(zip(self.layers, args))
|
||||||
|
|
||||||
for (f, g), (f_args, g_args) in layers_and_args:
|
for (f, g), (f_args, g_args) in layers_and_args:
|
||||||
x = x + f(x, **f_args)
|
x = x + checkpoint(functools.partial(f, **f_args), x)
|
||||||
x = x + g(x, **g_args)
|
x = x + checkpoint(functools.partial(g, **g_args), x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class ReversibleSequence(nn.Module):
|
class ReversibleSequence(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user