diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 64940b06..2484c939 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -253,9 +253,9 @@ class ExtensibleTrainer(BaseModel): log.update(s.get_metrics()) # Some generators can do their own metric logging. - for net in self.networks.values(): + for net_name, net in self.networks.items(): if hasattr(net.module, "get_debug_values"): - log.update(net.module.get_debug_values(step)) + log.update(net.module.get_debug_values(step, net_name)) return log def get_current_visuals(self, need_GT=True): diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 97905b1d..0361ae4a 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -395,7 +395,7 @@ class SwitchedSpsrWithRef2(nn.Module): prefix = "attention_map_%i_%%i.png" % (step,) [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] - def get_debug_values(self, step): + def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] means = [i[0] for i in mean_hists] @@ -527,7 +527,7 @@ class Spsr4(nn.Module): prefix = "attention_map_%i_%%i.png" % (step,) [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] - def get_debug_values(self, step): + def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] means = [i[0] for i in mean_hists] @@ -658,7 +658,7 @@ class Spsr5(nn.Module): prefix = "attention_map_%i_%%i.png" % (step,) [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] - def get_debug_values(self, step): + def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] means = [i[0] for i in mean_hists] diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index 3fd3487d..c5fd5923 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -210,7 +210,7 @@ class SSGr1(nn.Module): torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - def get_debug_values(self, step): + def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] means = [i[0] for i in mean_hists] diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index ad73c621..f2226d69 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -105,6 +105,7 @@ class ConfigurableStep(Module): for k, v in state.items(): local_state[k] = v[grad_accum_step] local_state.update(new_state) + local_state['train_nets'] = str(self.get_networks_trained()) # Inject in any extra dependencies. for inj in self.injectors: