Allow switched RRDBNet to record metrics and decay temperature
This commit is contained in:
parent
ae3301c0ea
commit
786a4288d6
|
@ -439,12 +439,17 @@ class SRGANModel(BaseModel):
|
|||
self.netG.train()
|
||||
|
||||
# Fetches a summary of the log.
|
||||
def get_current_log(self):
|
||||
def get_current_log(self, step):
|
||||
return_log = {}
|
||||
for k in self.log_dict.keys():
|
||||
if not isinstance(self.log_dict[k], list):
|
||||
continue
|
||||
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
|
||||
|
||||
# Some generators can do their own metric logging.
|
||||
if hasattr(self.netG.module, "get_debug_values"):
|
||||
return_log.update(self.netG.module.get_debug_values(step))
|
||||
|
||||
return return_log
|
||||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
|
|
|
@ -61,16 +61,49 @@ class RRDB(nn.Module):
|
|||
return out * 0.2 + x
|
||||
|
||||
class AttentiveRRDB(RRDB):
|
||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1):
|
||||
counter = 0
|
||||
|
||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = final_temperature_step
|
||||
self.running_mean = 0
|
||||
self.running_count = 0
|
||||
self.counter = AttentiveRRDB.counter
|
||||
AttentiveRRDB.counter += 1
|
||||
|
||||
def set_temperature(self, temp):
|
||||
self.RDB1.set_temperature(temp)
|
||||
self.RDB2.set_temperature(temp)
|
||||
self.RDB3.set_temperature(temp)
|
||||
self.RDB1.switcher.set_attention_temperature(temp)
|
||||
self.RDB2.switcher.set_attention_temperature(temp)
|
||||
self.RDB3.switcher.set_attention_temperature(temp)
|
||||
|
||||
def forward(self, x):
|
||||
out, att1 = self.RDB1(x, True)
|
||||
out, att2 = self.RDB2(out, True)
|
||||
out, att3 = self.RDB3(out, True)
|
||||
|
||||
a1mean, _ = switched_conv.compute_attention_specificity(att1, 2)
|
||||
a2mean, _ = switched_conv.compute_attention_specificity(att2, 2)
|
||||
a3mean, _ = switched_conv.compute_attention_specificity(att3, 2)
|
||||
self.running_mean += (a1mean + a2mean + a3mean) / 3.0
|
||||
self.running_count += 1
|
||||
|
||||
return out * 0.2 + x
|
||||
|
||||
def get_debug_values(self, step):
|
||||
# Take the chance to update the temperature here.
|
||||
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
|
||||
self.set_temperature(temp)
|
||||
|
||||
# Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
|
||||
val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
|
||||
"attention_temperature": temp}
|
||||
self.running_count = 0
|
||||
self.running_mean = 0
|
||||
return val
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
||||
|
@ -113,6 +146,13 @@ class RRDBNet(nn.Module):
|
|||
|
||||
return (out,)
|
||||
|
||||
def get_debug_values(self, step):
|
||||
val = {}
|
||||
for block in self.RRDB_trunk._modules.values():
|
||||
if hasattr(block, "get_debug_values"):
|
||||
val.update(block.get_debug_values(step))
|
||||
return val
|
||||
|
||||
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
|
||||
# intermediate layers have been splayed out, pixel-shuffled, and fed back in.
|
||||
class AssistedRRDBNet(nn.Module):
|
||||
|
|
|
@ -36,7 +36,8 @@ def define_G(opt, net_key='network_G'):
|
|||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
||||
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature']))
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step']))
|
||||
elif which_model == 'ResGen':
|
||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user