Get rid of get_debug_values from RRDB, rectify outputs

This commit is contained in:
James Betker 2020-09-19 21:46:36 -06:00
parent e0bd68efda
commit b83f097082

View File

@ -103,15 +103,6 @@ class RRDBTrunk(nn.Module):
fea = fea + trunk fea = fea + trunk
return fea return fea
def get_debug_values(self, step, prefix):
val = {}
i = 0
for block in self.RRDB_trunk._modules.values():
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i)))
i += 1
return val
# Adds some base methods that all RRDB* classes will use. # Adds some base methods that all RRDB* classes will use.
class RRDBBase(nn.Module): class RRDBBase(nn.Module):
@ -125,14 +116,6 @@ class RRDBBase(nn.Module):
for layer in trunk.rrdb_layers: for layer in trunk.rrdb_layers:
layer.set_temperature(temp) layer.set_temperature(temp)
def get_debug_values(self, step):
val = {}
for i, trunk in enumerate(self.trunks):
for j, block in enumerate(trunk.RRDB_trunk._modules.values()):
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
return val
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it. # This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
class RRDBNet(RRDBBase): class RRDBNet(RRDBBase):
@ -164,7 +147,7 @@ class RRDBNet(RRDBBase):
fea = self.lrelu(self.upconv2(fea)) fea = self.lrelu(self.upconv2(fea))
out = self.conv_last(self.lrelu(self.HRconv(fea))) out = self.conv_last(self.lrelu(self.HRconv(fea)))
return (out,) return out
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
# The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them. # The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.