Allow the majority of the srflow steps to checkpoint

This commit is contained in:
James Betker 2020-12-03 23:41:57 -07:00
parent 06d1c62c5a
commit dc9ff8e05b
2 changed files with 3 additions and 3 deletions

View File

@ -62,7 +62,7 @@ class FlowStep(nn.Module):
else:
raise RuntimeError("coupling not Found:", flow_coupling)
def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
def forward(self, input, logdet=None, rrdbResults=None, reverse=False):
if not reverse:
return self.normal_flow(input, logdet, rrdbResults)
else:

View File

@ -7,7 +7,7 @@ from models.archs.srflow_orig import flow, thops
from models.archs.srflow_orig.Split import Split2d
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
from models.archs.srflow_orig.FlowStep import FlowStep
from utils.util import opt_get
from utils.util import opt_get, checkpoint
class FlowUpsamplerNet(nn.Module):
@ -219,7 +219,7 @@ class FlowUpsamplerNet(nn.Module):
level_conditionals[level] = rrdbResults[self.levelToName[level]]
if isinstance(layer, FlowStep):
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level])
elif isinstance(layer, Split2d):
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
y_onehot=y_onehot)