Update LR layers to checkpoint better

James Betker 2022-01-22 08:22:57 -07:00
parent ce929a6b3f
commit 8ada52ccdc
2 changed files with 7 additions and 21 deletions


@@ -8,6 +8,8 @@ import torch.nn.functional as F
 import torch.nn.utils.spectral_norm as SpectralNorm
 from math import sqrt
 
+from utils.util import checkpoint
+
 
 def exists(val):
     return val is not None
@@ -211,24 +213,6 @@ def normalization(channels):
     return GroupNorm32(groups, channels)
 
 
-def checkpoint(func, inputs, params, flag):
-    """
-    Evaluate a function without caching intermediate activations, allowing for
-    reduced memory at the expense of extra compute in the backward pass.
-
-    :param func: the function to evaluate.
-    :param inputs: the argument sequence to pass to `func`.
-    :param params: a sequence of parameters `func` depends on but does not
-                   explicitly take as arguments.
-    :param flag: if False, disable gradient checkpointing.
-    """
-    if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
-    else:
-        return func(*inputs)
-
-
 class AttentionPool2d(nn.Module):
     """
     Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
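
The helper deleted above matches the checkpoint() wrapper in OpenAI's guided-diffusion nn.py and delegates to a CheckpointFunction autograd class that does not appear in this diff. For background, here is a minimal sketch of what that class conventionally looks like (following the guided-diffusion pattern); the class actually referenced by this repository may differ.

import torch

class CheckpointFunction(torch.autograd.Function):
    # Sketch only: re-runs the wrapped function during backward instead of
    # caching its intermediate activations during forward.
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            # No graph is recorded here, so intermediate activations are freed.
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Recompute the forward pass with grad enabled so gradients can be
            # taken w.r.t. both the inputs and the extra parameters.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
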
@@ -506,11 +490,14 @@ class AttentionBlock(nn.Module):
 
     def forward(self, x, mask=None):
         if self.do_checkpoint:
-            return checkpoint(self._forward, x, mask)
+            if mask is not None:
+                return checkpoint(self._forward, x, mask)
+            else:
+                return checkpoint(self._forward, x)
         else:
             return self._forward(x, mask)
 
-    def _forward(self, x, mask):
+    def _forward(self, x, mask=None):
         b, c, *spatial = x.shape
         x = x.reshape(b, c, -1)
         qkv = self.qkv(self.norm(x))
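
The rewritten forward() above only routes mask through the checkpoint call when it is an actual tensor; a None mask is dropped before crossing the checkpoint boundary and _forward falls back to its default. The implementation of the checkpoint helper now imported from utils.util is not shown in this diff, so the sketch below illustrates the same calling convention using PyTorch's built-in torch.utils.checkpoint as a stand-in; the module, shapes, and mask semantics here are invented for the example.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

class TinyBlock(nn.Module):
    # Hypothetical stand-in mirroring the _forward(x, mask=None) signature above.
    def __init__(self, channels):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels, 1)

    def _forward(self, x, mask=None):
        h = self.proj(x)
        if mask is not None:
            # Zero out masked positions; the mask broadcasts over channels.
            h = h.masked_fill(~mask, 0.0)
        return h

    def forward(self, x, mask=None):
        # Mirror the diff: only pass the mask across the checkpoint boundary
        # when it is a real tensor, otherwise checkpoint with x alone.
        # (use_reentrant requires a reasonably recent PyTorch.)
        if mask is not None:
            return torch_checkpoint(self._forward, x, mask, use_reentrant=False)
        else:
            return torch_checkpoint(self._forward, x, use_reentrant=False)

block = TinyBlock(16)
x = torch.randn(2, 16, 64, requires_grad=True)
mask = torch.rand(2, 1, 64) > 0.1
block(x).sum().backward()        # checkpointed, no mask
block(x, mask).sum().backward()  # checkpointed, with mask
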

@@ -1 +0,0 @@
-Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3