forked from mrq/DL-Art-School
Allow the majority of the srflow steps to checkpoint
This commit is contained in:
parent
06d1c62c5a
commit
dc9ff8e05b
|
@ -62,7 +62,7 @@ class FlowStep(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("coupling not Found:", flow_coupling)
|
raise RuntimeError("coupling not Found:", flow_coupling)
|
||||||
|
|
||||||
def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
|
def forward(self, input, logdet=None, rrdbResults=None, reverse=False):
|
||||||
if not reverse:
|
if not reverse:
|
||||||
return self.normal_flow(input, logdet, rrdbResults)
|
return self.normal_flow(input, logdet, rrdbResults)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from models.archs.srflow_orig import flow, thops
|
||||||
from models.archs.srflow_orig.Split import Split2d
|
from models.archs.srflow_orig.Split import Split2d
|
||||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
class FlowUpsamplerNet(nn.Module):
|
class FlowUpsamplerNet(nn.Module):
|
||||||
|
@ -219,7 +219,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
||||||
|
|
||||||
if isinstance(layer, FlowStep):
|
if isinstance(layer, FlowStep):
|
||||||
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
|
fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level])
|
||||||
elif isinstance(layer, Split2d):
|
elif isinstance(layer, Split2d):
|
||||||
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
|
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
|
||||||
y_onehot=y_onehot)
|
y_onehot=y_onehot)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user