diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index f5ca03e3..ba492b70 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -5,6 +5,7 @@ import torch.nn.functional as F import models.archs.arch_util as arch_util from models.archs.arch_util import PixelUnshuffle import torchvision +from torch.utils.checkpoint import checkpoint class ResidualDenseBlock_5C(nn.Module): @@ -41,9 +42,9 @@ class RRDB(nn.Module): self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) + out = checkpoint(self.RDB1, x) + out = checkpoint(self.RDB2, out) + out = checkpoint(self.RDB3, out) return out * 0.2 + x