Checkpoint RRDB
Greatly reduces memory consumption with a low performance penalty
This commit is contained in:
parent
8580490a85
commit
bfdfaab911
|
@ -5,6 +5,7 @@ import torch.nn.functional as F
|
|||
import models.archs.arch_util as arch_util
|
||||
from models.archs.arch_util import PixelUnshuffle
|
||||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
|
@ -41,9 +42,9 @@ class RRDB(nn.Module):
|
|||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
out = checkpoint(self.RDB1, x)
|
||||
out = checkpoint(self.RDB2, out)
|
||||
out = checkpoint(self.RDB3, out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user