diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 6aa44df2..0a08a352 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -133,14 +133,18 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         if self.attentions:
             temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
             if temp == 1:
-                # Enter the linear function regime between self.final_temperature_step and self.heightened_final_step
+                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
+                # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
                 h_steps_total = self.heightened_final_step - self.final_temperature_step
                 h_steps_current = min(step - self.final_temperature_step, h_steps_total)
-                h_gap = 1 - self.heightened_temp_min
-                temp = 1 - h_gap * h_steps_current / h_steps_total
+                # The "gap" will represent the steps that need to be traveled as a linear function.
+                h_gap = 1 / self.heightened_temp_min
+                temp = h_gap * h_steps_current / h_steps_total
+                # Invert temperature to represent reality on this side of the curve
+                temp = 1 / temp
             self.set_temperature(temp)
             if step % 50 == 0:
-                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,), l_mult=float(self.transformation_counts[i]/4)) for i in range(len(self.switches))]
+                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
 
     def get_debug_values(self, step):
         temp = self.switches[0].switch.temperature
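
For reference, here is a minimal standalone sketch of the temperature schedule this change implements. The function name, the default values, the `step >= final_temperature_step` guard, and the divide-by-zero clamp at the phase boundary are illustrative assumptions, not part of the diff; the rest mirrors the logic above.

```python
def scheduled_temperature(step,
                          init_temperature=10,
                          final_temperature_step=50000,
                          heightened_final_step=100000,
                          heightened_temp_min=0.1):
    # Phase 1: linear decay from init_temperature down to 1 over final_temperature_step steps.
    temp = max(1, int(init_temperature * (final_temperature_step - step) / final_temperature_step))
    if temp == 1 and step >= final_temperature_step:  # step guard is sketch-only
        # Phase 2: inverted (1/x) curve from 1 down toward heightened_temp_min, so that
        # 1/temp (attention specificity) grows linearly instead of spiking at the end.
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(step - final_temperature_step, h_steps_total)
        h_steps_current = max(h_steps_current, 1)  # sketch-only: avoid div-by-zero at the boundary
        h_gap = 1 / heightened_temp_min
        temp = 1 / (h_gap * h_steps_current / h_steps_total)
    return temp


if __name__ == "__main__":
    for s in (0, 25000, 55000, 75000, 100000):
        print(s, scheduled_temperature(s))
    # With the example defaults above: 10, 5, 1.0, 0.2, 0.1
```

With the old formula (`temp = 1 - (1 - heightened_temp_min) * fraction`), the growth in 1/temp was heavily concentrated near the end of the phase; with the inverted curve, 1/temp grows linearly across the phase and reaches 1/heightened_temp_min at heightened_final_step, which is the "spike" the new comments describe avoiding.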