Some convenience adjustments to ExtensibleTrainer

This commit is contained in:
James Betker 2020-09-17 21:05:32 -06:00
parent 57fc3f490c
commit 9a17ade550
4 changed files with 7 additions and 6 deletions

View File

@ -253,9 +253,9 @@ class ExtensibleTrainer(BaseModel):
log.update(s.get_metrics())
# Some generators can do their own metric logging.
for net in self.networks.values():
for net_name, net in self.networks.items():
if hasattr(net.module, "get_debug_values"):
log.update(net.module.get_debug_values(step))
log.update(net.module.get_debug_values(step, net_name))
return log
def get_current_visuals(self, need_GT=True):

View File

@ -395,7 +395,7 @@ class SwitchedSpsrWithRef2(nn.Module):
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
def get_debug_values(self, step):
def get_debug_values(self, step, net_name):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
@ -527,7 +527,7 @@ class Spsr4(nn.Module):
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
def get_debug_values(self, step):
def get_debug_values(self, step, net_name):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
@ -658,7 +658,7 @@ class Spsr5(nn.Module):
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
def get_debug_values(self, step):
def get_debug_values(self, step, net_name):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]

View File

@ -210,7 +210,7 @@ class SSGr1(nn.Module):
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
def get_debug_values(self, step):
def get_debug_values(self, step, net_name):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]

View File

@ -105,6 +105,7 @@ class ConfigurableStep(Module):
for k, v in state.items():
local_state[k] = v[grad_accum_step]
local_state.update(new_state)
local_state['train_nets'] = str(self.get_networks_trained())
# Inject in any extra dependencies.
for inj in self.injectors: