diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 88f705b8..f26f2aab 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -439,12 +439,17 @@ class SRGANModel(BaseModel): self.netG.train() # Fetches a summary of the log. - def get_current_log(self): + def get_current_log(self, step): return_log = {} for k in self.log_dict.keys(): if not isinstance(self.log_dict[k], list): continue return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) + + # Some generators can do their own metric logging. + if hasattr(self.netG.module, "get_debug_values"): + return_log.update(self.netG.module.get_debug_values(step)) + return return_log def get_current_visuals(self, need_GT=True): diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 2399e1c9..ad0f0297 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -61,16 +61,49 @@ class RRDB(nn.Module): return out * 0.2 + x class AttentiveRRDB(RRDB): - def __init__(self, nf, gc=32, num_convs=8, init_temperature=1): + counter = 0 + + def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1): super(RRDB, self).__init__() self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) + self.init_temperature = init_temperature + self.final_temperature_step = final_temperature_step + self.running_mean = 0 + self.running_count = 0 + self.counter = AttentiveRRDB.counter + AttentiveRRDB.counter += 1 def set_temperature(self, temp): - self.RDB1.set_temperature(temp) - self.RDB2.set_temperature(temp) - self.RDB3.set_temperature(temp) + self.RDB1.switcher.set_attention_temperature(temp) + self.RDB2.switcher.set_attention_temperature(temp) + self.RDB3.switcher.set_attention_temperature(temp) + + def forward(self, x): + out, att1 = self.RDB1(x, True) + out, att2 = self.RDB2(out, True) + out, att3 = self.RDB3(out, True) + + a1mean, _ = switched_conv.compute_attention_specificity(att1, 2) + a2mean, _ = switched_conv.compute_attention_specificity(att2, 2) + a3mean, _ = switched_conv.compute_attention_specificity(att3, 2) + self.running_mean += (a1mean + a2mean + a3mean) / 3.0 + self.running_count += 1 + + return out * 0.2 + x + + def get_debug_values(self, step): + # Take the chance to update the temperature here. + temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)) + self.set_temperature(temp) + + # Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced. + val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count, + "attention_temperature": temp} + self.running_count = 0 + self.running_mean = 0 + return val class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1, @@ -113,6 +146,13 @@ class RRDBNet(nn.Module): return (out,) + def get_debug_values(self, step): + val = {} + for block in self.RRDB_trunk._modules.values(): + if hasattr(block, "get_debug_values"): + val.update(block.get_debug_values(step)) + return val + # Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose # intermediate layers have been splayed out, pixel-shuffled, and fed back in. class AssistedRRDBNet(nn.Module): diff --git a/codes/models/networks.py b/codes/models/networks.py index fa10056b..0671c6a8 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -36,7 +36,8 @@ def define_G(opt, net_key='network_G'): netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'], - init_temperature=opt_net['temperature'])) + init_temperature=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step'])) elif which_model == 'ResGen': netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])