From 8ada52ccdc2e223acf22fb39e2bd99e460e25f09 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 22 Jan 2022 08:22:57 -0700 Subject: [PATCH] Update LR layers to checkpoint better --- codes/models/arch_util.py | 27 +++++++-------------------- codes/models/flownet2 | 1 - 2 files changed, 7 insertions(+), 21 deletions(-) delete mode 160000 codes/models/flownet2 diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 227b5eff..159a012c 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -8,6 +8,8 @@ import torch.nn.functional as F import torch.nn.utils.spectral_norm as SpectralNorm from math import sqrt +from utils.util import checkpoint + def exists(val): return val is not None @@ -211,24 +213,6 @@ def normalization(channels): return GroupNorm32(groups, channels) -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py @@ -506,11 +490,14 @@ class AttentionBlock(nn.Module): def forward(self, x, mask=None): if self.do_checkpoint: - return checkpoint(self._forward, x, mask) + if mask is not None: + return checkpoint(self._forward, x, mask) + else: + return checkpoint(self._forward, x) else: return self._forward(x, mask) - def _forward(self, x, mask): + def _forward(self, x, mask=None): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) diff --git a/codes/models/flownet2 b/codes/models/flownet2 deleted file mode 160000 index db2b7899..00000000 --- a/codes/models/flownet2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3