diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index ba492b70..65bcc569 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -103,15 +103,6 @@ class RRDBTrunk(nn.Module): fea = fea + trunk return fea - def get_debug_values(self, step, prefix): - val = {} - i = 0 - for block in self.RRDB_trunk._modules.values(): - if hasattr(block, "get_debug_values"): - val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i))) - i += 1 - return val - # Adds some base methods that all RRDB* classes will use. class RRDBBase(nn.Module): @@ -125,14 +116,6 @@ class RRDBBase(nn.Module): for layer in trunk.rrdb_layers: layer.set_temperature(temp) - def get_debug_values(self, step): - val = {} - for i, trunk in enumerate(self.trunks): - for j, block in enumerate(trunk.RRDB_trunk._modules.values()): - if hasattr(block, "get_debug_values"): - val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j))) - return val - # This class uses a RRDBTrunk to perform processing on an image, then upsamples it. class RRDBNet(RRDBBase): @@ -164,7 +147,7 @@ class RRDBNet(RRDBBase): fea = self.lrelu(self.upconv2(fea)) out = self.conv_last(self.lrelu(self.HRconv(fea))) - return (out,) + return out def load_state_dict(self, state_dict, strict=True): # The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.