forked from mrq/DL-Art-School
Allow switched RRDBNet to record metrics and decay temperature
This commit is contained in:
parent
ae3301c0ea
commit
786a4288d6
|
@ -439,12 +439,17 @@ class SRGANModel(BaseModel):
|
||||||
self.netG.train()
|
self.netG.train()
|
||||||
|
|
||||||
# Fetches a summary of the log.
|
# Fetches a summary of the log.
|
||||||
def get_current_log(self):
|
def get_current_log(self, step):
|
||||||
return_log = {}
|
return_log = {}
|
||||||
for k in self.log_dict.keys():
|
for k in self.log_dict.keys():
|
||||||
if not isinstance(self.log_dict[k], list):
|
if not isinstance(self.log_dict[k], list):
|
||||||
continue
|
continue
|
||||||
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
|
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
|
||||||
|
|
||||||
|
# Some generators can do their own metric logging.
|
||||||
|
if hasattr(self.netG.module, "get_debug_values"):
|
||||||
|
return_log.update(self.netG.module.get_debug_values(step))
|
||||||
|
|
||||||
return return_log
|
return return_log
|
||||||
|
|
||||||
def get_current_visuals(self, need_GT=True):
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
|
|
@ -61,16 +61,49 @@ class RRDB(nn.Module):
|
||||||
return out * 0.2 + x
|
return out * 0.2 + x
|
||||||
|
|
||||||
class AttentiveRRDB(RRDB):
|
class AttentiveRRDB(RRDB):
|
||||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1):
|
counter = 0
|
||||||
|
|
||||||
|
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||||
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||||
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||||
|
self.init_temperature = init_temperature
|
||||||
|
self.final_temperature_step = final_temperature_step
|
||||||
|
self.running_mean = 0
|
||||||
|
self.running_count = 0
|
||||||
|
self.counter = AttentiveRRDB.counter
|
||||||
|
AttentiveRRDB.counter += 1
|
||||||
|
|
||||||
def set_temperature(self, temp):
|
def set_temperature(self, temp):
|
||||||
self.RDB1.set_temperature(temp)
|
self.RDB1.switcher.set_attention_temperature(temp)
|
||||||
self.RDB2.set_temperature(temp)
|
self.RDB2.switcher.set_attention_temperature(temp)
|
||||||
self.RDB3.set_temperature(temp)
|
self.RDB3.switcher.set_attention_temperature(temp)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out, att1 = self.RDB1(x, True)
|
||||||
|
out, att2 = self.RDB2(out, True)
|
||||||
|
out, att3 = self.RDB3(out, True)
|
||||||
|
|
||||||
|
a1mean, _ = switched_conv.compute_attention_specificity(att1, 2)
|
||||||
|
a2mean, _ = switched_conv.compute_attention_specificity(att2, 2)
|
||||||
|
a3mean, _ = switched_conv.compute_attention_specificity(att3, 2)
|
||||||
|
self.running_mean += (a1mean + a2mean + a3mean) / 3.0
|
||||||
|
self.running_count += 1
|
||||||
|
|
||||||
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
def get_debug_values(self, step):
|
||||||
|
# Take the chance to update the temperature here.
|
||||||
|
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
|
||||||
|
self.set_temperature(temp)
|
||||||
|
|
||||||
|
# Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
|
||||||
|
val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
|
||||||
|
"attention_temperature": temp}
|
||||||
|
self.running_count = 0
|
||||||
|
self.running_mean = 0
|
||||||
|
return val
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class RRDBNet(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
||||||
|
@ -113,6 +146,13 @@ class RRDBNet(nn.Module):
|
||||||
|
|
||||||
return (out,)
|
return (out,)
|
||||||
|
|
||||||
|
def get_debug_values(self, step):
|
||||||
|
val = {}
|
||||||
|
for block in self.RRDB_trunk._modules.values():
|
||||||
|
if hasattr(block, "get_debug_values"):
|
||||||
|
val.update(block.get_debug_values(step))
|
||||||
|
return val
|
||||||
|
|
||||||
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
|
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
|
||||||
# intermediate layers have been splayed out, pixel-shuffled, and fed back in.
|
# intermediate layers have been splayed out, pixel-shuffled, and fed back in.
|
||||||
class AssistedRRDBNet(nn.Module):
|
class AssistedRRDBNet(nn.Module):
|
||||||
|
|
|
@ -36,7 +36,8 @@ def define_G(opt, net_key='network_G'):
|
||||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
||||||
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
init_temperature=opt_net['temperature']))
|
init_temperature=opt_net['temperature'],
|
||||||
|
final_temperature_step=opt_net['temperature_final_step']))
|
||||||
elif which_model == 'ResGen':
|
elif which_model == 'ResGen':
|
||||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||||
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
|
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user