From dc9ff8e05bdb69c084b6b733fbb640dc26312166 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 3 Dec 2020 23:41:57 -0700
Subject: [PATCH] Allow the majority of the srflow steps to checkpoint

---
 codes/models/archs/srflow_orig/FlowStep.py         | 2 +-
 codes/models/archs/srflow_orig/FlowUpsamplerNet.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/archs/srflow_orig/FlowStep.py b/codes/models/archs/srflow_orig/FlowStep.py
index 1c128c92..af68dae9 100644
--- a/codes/models/archs/srflow_orig/FlowStep.py
+++ b/codes/models/archs/srflow_orig/FlowStep.py
@@ -62,7 +62,7 @@ class FlowStep(nn.Module):
         else:
             raise RuntimeError("coupling not Found:", flow_coupling)
 
-    def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
+    def forward(self, input, logdet=None, rrdbResults=None, reverse=False):
         if not reverse:
             return self.normal_flow(input, logdet, rrdbResults)
         else:
diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
index 0b595128..3282e702 100644
--- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
+++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
@@ -7,7 +7,7 @@ from models.archs.srflow_orig import flow, thops
 from models.archs.srflow_orig.Split import Split2d
 from models.archs.srflow_orig.glow_arch import f_conv2d_bias
 from models.archs.srflow_orig.FlowStep import FlowStep
-from utils.util import opt_get
+from utils.util import opt_get, checkpoint
 
 
 class FlowUpsamplerNet(nn.Module):
@@ -219,7 +219,7 @@ class FlowUpsamplerNet(nn.Module):
             level_conditionals[level] = rrdbResults[self.levelToName[level]]
 
             if isinstance(layer, FlowStep):
-                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
+                fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level])
             elif isinstance(layer, Split2d):
                 fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
                                                       y_onehot=y_onehot)