forked from mrq/DL-Art-School
Update LR layers to checkpoint better
This commit is contained in:
parent
ce929a6b3f
commit
8ada52ccdc
|
@ -8,6 +8,8 @@ import torch.nn.functional as F
|
||||||
import torch.nn.utils.spectral_norm as SpectralNorm
|
import torch.nn.utils.spectral_norm as SpectralNorm
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
|
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
@ -211,24 +213,6 @@ def normalization(channels):
|
||||||
return GroupNorm32(groups, channels)
|
return GroupNorm32(groups, channels)
|
||||||
|
|
||||||
|
|
||||||
def checkpoint(func, inputs, params, flag):
|
|
||||||
"""
|
|
||||||
Evaluate a function without caching intermediate activations, allowing for
|
|
||||||
reduced memory at the expense of extra compute in the backward pass.
|
|
||||||
|
|
||||||
:param func: the function to evaluate.
|
|
||||||
:param inputs: the argument sequence to pass to `func`.
|
|
||||||
:param params: a sequence of parameters `func` depends on but does not
|
|
||||||
explicitly take as arguments.
|
|
||||||
:param flag: if False, disable gradient checkpointing.
|
|
||||||
"""
|
|
||||||
if flag:
|
|
||||||
args = tuple(inputs) + tuple(params)
|
|
||||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
|
||||||
else:
|
|
||||||
return func(*inputs)
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionPool2d(nn.Module):
|
class AttentionPool2d(nn.Module):
|
||||||
"""
|
"""
|
||||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||||
|
@ -506,11 +490,14 @@ class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
def forward(self, x, mask=None):
|
||||||
if self.do_checkpoint:
|
if self.do_checkpoint:
|
||||||
return checkpoint(self._forward, x, mask)
|
if mask is not None:
|
||||||
|
return checkpoint(self._forward, x, mask)
|
||||||
|
else:
|
||||||
|
return checkpoint(self._forward, x)
|
||||||
else:
|
else:
|
||||||
return self._forward(x, mask)
|
return self._forward(x, mask)
|
||||||
|
|
||||||
def _forward(self, x, mask):
|
def _forward(self, x, mask=None):
|
||||||
b, c, *spatial = x.shape
|
b, c, *spatial = x.shape
|
||||||
x = x.reshape(b, c, -1)
|
x = x.reshape(b, c, -1)
|
||||||
qkv = self.qkv(self.norm(x))
|
qkv = self.qkv(self.norm(x))
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3
|
|
Loading…
Reference in New Issue
Block a user