From bfdfaab91190a86ccaaaa4dc4d87d5debb186e4f Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Sep 2020 15:32:00 -0600 Subject: [PATCH] Checkpoint RRDB Greatly reduces memory consumption with a low performance penalty --- codes/models/archs/RRDBNet_arch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index f5ca03e3..ba492b70 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -5,6 +5,7 @@ import torch.nn.functional as F import models.archs.arch_util as arch_util from models.archs.arch_util import PixelUnshuffle import torchvision +from torch.utils.checkpoint import checkpoint class ResidualDenseBlock_5C(nn.Module): @@ -41,9 +42,9 @@ class RRDB(nn.Module): self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) + out = checkpoint(self.RDB1, x) + out = checkpoint(self.RDB2, out) + out = checkpoint(self.RDB3, out) return out * 0.2 + x