diff --git a/codes/models/archs/srflow_orig/FlowStep.py b/codes/models/archs/srflow_orig/FlowStep.py index 1c128c92..af68dae9 100644 --- a/codes/models/archs/srflow_orig/FlowStep.py +++ b/codes/models/archs/srflow_orig/FlowStep.py @@ -62,7 +62,7 @@ class FlowStep(nn.Module): else: raise RuntimeError("coupling not Found:", flow_coupling) - def forward(self, input, logdet=None, reverse=False, rrdbResults=None): + def forward(self, input, logdet=None, rrdbResults=None, reverse=False): if not reverse: return self.normal_flow(input, logdet, rrdbResults) else: diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index 0b595128..3282e702 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -7,7 +7,7 @@ from models.archs.srflow_orig import flow, thops from models.archs.srflow_orig.Split import Split2d from models.archs.srflow_orig.glow_arch import f_conv2d_bias from models.archs.srflow_orig.FlowStep import FlowStep -from utils.util import opt_get +from utils.util import opt_get, checkpoint class FlowUpsamplerNet(nn.Module): @@ -219,7 +219,7 @@ class FlowUpsamplerNet(nn.Module): level_conditionals[level] = rrdbResults[self.levelToName[level]] if isinstance(layer, FlowStep): - fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level]) + fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level]) elif isinstance(layer, Split2d): fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level], y_onehot=y_onehot)