Checkpoint RRDB

Greatly reduces memory consumption with a low performance penalty
This commit is contained in:
James Betker 2020-09-04 15:32:00 -06:00
parent 8580490a85
commit bfdfaab911

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
import models.archs.arch_util as arch_util
from models.archs.arch_util import PixelUnshuffle
import torchvision
from torch.utils.checkpoint import checkpoint
class ResidualDenseBlock_5C(nn.Module):
@ -41,9 +42,9 @@ class RRDB(nn.Module):
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
out = checkpoint(self.RDB1, x)
out = checkpoint(self.RDB2, out)
out = checkpoint(self.RDB3, out)
return out * 0.2 + x