Get rid of get_debug_values from RRDB, rectify outputs
This commit is contained in:
parent
e0bd68efda
commit
b83f097082
|
@ -103,15 +103,6 @@ class RRDBTrunk(nn.Module):
|
|||
fea = fea + trunk
|
||||
return fea
|
||||
|
||||
def get_debug_values(self, step, prefix):
|
||||
val = {}
|
||||
i = 0
|
||||
for block in self.RRDB_trunk._modules.values():
|
||||
if hasattr(block, "get_debug_values"):
|
||||
val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i)))
|
||||
i += 1
|
||||
return val
|
||||
|
||||
|
||||
# Adds some base methods that all RRDB* classes will use.
|
||||
class RRDBBase(nn.Module):
|
||||
|
@ -125,14 +116,6 @@ class RRDBBase(nn.Module):
|
|||
for layer in trunk.rrdb_layers:
|
||||
layer.set_temperature(temp)
|
||||
|
||||
def get_debug_values(self, step):
|
||||
val = {}
|
||||
for i, trunk in enumerate(self.trunks):
|
||||
for j, block in enumerate(trunk.RRDB_trunk._modules.values()):
|
||||
if hasattr(block, "get_debug_values"):
|
||||
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
|
||||
return val
|
||||
|
||||
|
||||
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
||||
class RRDBNet(RRDBBase):
|
||||
|
@ -164,7 +147,7 @@ class RRDBNet(RRDBBase):
|
|||
fea = self.lrelu(self.upconv2(fea))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return (out,)
|
||||
return out
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.
|
||||
|
|
Loading…
Reference in New Issue
Block a user