Checkpoint RRDB
Greatly reduces memory consumption with a low performance penalty
This commit is contained in:
parent
8580490a85
commit
bfdfaab911
|
@ -5,6 +5,7 @@ import torch.nn.functional as F
|
||||||
import models.archs.arch_util as arch_util
|
import models.archs.arch_util as arch_util
|
||||||
from models.archs.arch_util import PixelUnshuffle
|
from models.archs.arch_util import PixelUnshuffle
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
|
class ResidualDenseBlock_5C(nn.Module):
|
||||||
|
@ -41,9 +42,9 @@ class RRDB(nn.Module):
|
||||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.RDB1(x)
|
out = checkpoint(self.RDB1, x)
|
||||||
out = self.RDB2(out)
|
out = checkpoint(self.RDB2, out)
|
||||||
out = self.RDB3(out)
|
out = checkpoint(self.RDB3, out)
|
||||||
return out * 0.2 + x
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user