Allow the majority of the srflow steps to checkpoint
This commit is contained in:
parent
06d1c62c5a
commit
dc9ff8e05b
|
@ -62,7 +62,7 @@ class FlowStep(nn.Module):
|
|||
else:
|
||||
raise RuntimeError("coupling not Found:", flow_coupling)
|
||||
|
||||
def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
|
||||
def forward(self, input, logdet=None, rrdbResults=None, reverse=False):
|
||||
if not reverse:
|
||||
return self.normal_flow(input, logdet, rrdbResults)
|
||||
else:
|
||||
|
|
|
@ -7,7 +7,7 @@ from models.archs.srflow_orig import flow, thops
|
|||
from models.archs.srflow_orig.Split import Split2d
|
||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||
from utils.util import opt_get
|
||||
from utils.util import opt_get, checkpoint
|
||||
|
||||
|
||||
class FlowUpsamplerNet(nn.Module):
|
||||
|
@ -219,7 +219,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
||||
|
||||
if isinstance(layer, FlowStep):
|
||||
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
|
||||
fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level])
|
||||
elif isinstance(layer, Split2d):
|
||||
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
|
||||
y_onehot=y_onehot)
|
||||
|
|
Loading…
Reference in New Issue
Block a user