forked from mrq/DL-Art-School
Some convenience adjustments to ExtensibleTrainer
This commit is contained in:
parent
57fc3f490c
commit
9a17ade550
|
@ -253,9 +253,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
log.update(s.get_metrics())
|
||||
|
||||
# Some generators can do their own metric logging.
|
||||
for net in self.networks.values():
|
||||
for net_name, net in self.networks.items():
|
||||
if hasattr(net.module, "get_debug_values"):
|
||||
log.update(net.module.get_debug_values(step))
|
||||
log.update(net.module.get_debug_values(step, net_name))
|
||||
return log
|
||||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
|
|
|
@ -395,7 +395,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
def get_debug_values(self, step, net_name):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
|
@ -527,7 +527,7 @@ class Spsr4(nn.Module):
|
|||
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
def get_debug_values(self, step, net_name):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
|
@ -658,7 +658,7 @@ class Spsr5(nn.Module):
|
|||
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
def get_debug_values(self, step, net_name):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
|
|
|
@ -210,7 +210,7 @@ class SSGr1(nn.Module):
|
|||
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
|
||||
|
||||
|
||||
def get_debug_values(self, step):
|
||||
def get_debug_values(self, step, net_name):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
|
|
|
@ -105,6 +105,7 @@ class ConfigurableStep(Module):
|
|||
for k, v in state.items():
|
||||
local_state[k] = v[grad_accum_step]
|
||||
local_state.update(new_state)
|
||||
local_state['train_nets'] = str(self.get_networks_trained())
|
||||
|
||||
# Inject in any extra dependencies.
|
||||
for inj in self.injectors:
|
||||
|
|
Loading…
Reference in New Issue
Block a user