Allow switched RRDBNet to record metrics and decay temperature

This commit is contained in:
James Betker 2020-06-08 11:10:38 -06:00
parent ae3301c0ea
commit 786a4288d6
3 changed files with 52 additions and 6 deletions

View File

@ -439,12 +439,17 @@ class SRGANModel(BaseModel):
self.netG.train() self.netG.train()
# Fetches a summary of the log. # Fetches a summary of the log.
def get_current_log(self): def get_current_log(self, step):
return_log = {} return_log = {}
for k in self.log_dict.keys(): for k in self.log_dict.keys():
if not isinstance(self.log_dict[k], list): if not isinstance(self.log_dict[k], list):
continue continue
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
# Some generators can do their own metric logging.
if hasattr(self.netG.module, "get_debug_values"):
return_log.update(self.netG.module.get_debug_values(step))
return return_log return return_log
def get_current_visuals(self, need_GT=True): def get_current_visuals(self, need_GT=True):

View File

@ -61,16 +61,49 @@ class RRDB(nn.Module):
return out * 0.2 + x return out * 0.2 + x
class AttentiveRRDB(RRDB): class AttentiveRRDB(RRDB):
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1): counter = 0
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
super(RRDB, self).__init__() super(RRDB, self).__init__()
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature) self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
self.init_temperature = init_temperature
self.final_temperature_step = final_temperature_step
self.running_mean = 0
self.running_count = 0
self.counter = AttentiveRRDB.counter
AttentiveRRDB.counter += 1
def set_temperature(self, temp): def set_temperature(self, temp):
self.RDB1.set_temperature(temp) self.RDB1.switcher.set_attention_temperature(temp)
self.RDB2.set_temperature(temp) self.RDB2.switcher.set_attention_temperature(temp)
self.RDB3.set_temperature(temp) self.RDB3.switcher.set_attention_temperature(temp)
def forward(self, x):
out, att1 = self.RDB1(x, True)
out, att2 = self.RDB2(out, True)
out, att3 = self.RDB3(out, True)
a1mean, _ = switched_conv.compute_attention_specificity(att1, 2)
a2mean, _ = switched_conv.compute_attention_specificity(att2, 2)
a3mean, _ = switched_conv.compute_attention_specificity(att3, 2)
self.running_mean += (a1mean + a2mean + a3mean) / 3.0
self.running_count += 1
return out * 0.2 + x
def get_debug_values(self, step):
# Take the chance to update the temperature here.
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
self.set_temperature(temp)
# Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
"attention_temperature": temp}
self.running_count = 0
self.running_mean = 0
return val
class RRDBNet(nn.Module): class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1, def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
@ -113,6 +146,13 @@ class RRDBNet(nn.Module):
return (out,) return (out,)
def get_debug_values(self, step):
val = {}
for block in self.RRDB_trunk._modules.values():
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step))
return val
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose # Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
# intermediate layers have been splayed out, pixel-shuffled, and fed back in. # intermediate layers have been splayed out, pixel-shuffled, and fed back in.
class AssistedRRDBNet(nn.Module): class AssistedRRDBNet(nn.Module):

View File

@ -36,7 +36,8 @@ def define_G(opt, net_key='network_G'):
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'], rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'])) init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step']))
elif which_model == 'ResGen': elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])