forked from mrq/DL-Art-School
Update LR layers to checkpoint better
This commit is contained in:
parent
ce929a6b3f
commit
8ada52ccdc
|
@ -8,6 +8,8 @@ import torch.nn.functional as F
|
|||
import torch.nn.utils.spectral_norm as SpectralNorm
|
||||
from math import sqrt
|
||||
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
@ -211,24 +213,6 @@ def normalization(channels):
|
|||
return GroupNorm32(groups, channels)
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
|
@ -506,11 +490,14 @@ class AttentionBlock(nn.Module):
|
|||
|
||||
def forward(self, x, mask=None):
|
||||
if self.do_checkpoint:
|
||||
return checkpoint(self._forward, x, mask)
|
||||
if mask is not None:
|
||||
return checkpoint(self._forward, x, mask)
|
||||
else:
|
||||
return checkpoint(self._forward, x)
|
||||
else:
|
||||
return self._forward(x, mask)
|
||||
|
||||
def _forward(self, x, mask):
|
||||
def _forward(self, x, mask=None):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3
|
Loading…
Reference in New Issue
Block a user