From 7659bd6818af25424bd7268fba5d35ffd97ba35d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 14 Jul 2020 10:17:14 -0600 Subject: [PATCH] Fix temperature equation --- .../archs/SwitchedResidualGenerator_arch.py | 7 +++--- codes/utils/onnx_inference.py | 24 +++++++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 8a0b9de8..3e19fe38 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -294,9 +294,10 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module): def update_for_step(self, step, experiments_path='.'): if self.attentions: - temp = max(1, int( - self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)) - if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1: + temp = max(1, + 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) + if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ + self.heightened_final_step != 1: # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. # without this, the attention specificity "spikes" incredibly fast in the last few iterations. h_steps_total = self.heightened_final_step - self.final_temperature_step diff --git a/codes/utils/onnx_inference.py b/codes/utils/onnx_inference.py index a2f214a4..56814b12 100644 --- a/codes/utils/onnx_inference.py +++ b/codes/utils/onnx_inference.py @@ -2,11 +2,21 @@ import onnx import numpy as np import time -model = onnx.load('../results/gen.onnx') +init_temperature = 10 +final_temperature_step = 50 +heightened_final_step = 90 +heightened_temp_min = .1 -outputs = {} -for n in model.graph.node: - for o in n.output: - outputs[o] = n - -res = 0 \ No newline at end of file +for step in range(100): + temp = max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step) + if temp == 1 and step > final_temperature_step and heightened_final_step and heightened_final_step != 1: + # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. + # without this, the attention specificity "spikes" incredibly fast in the last few iterations. + h_steps_total = heightened_final_step - final_temperature_step + h_steps_current = min(step - final_temperature_step, h_steps_total) + # The "gap" will represent the steps that need to be traveled as a linear function. + h_gap = 1 / heightened_temp_min + temp = h_gap * h_steps_current / h_steps_total + # Invert temperature to represent reality on this side of the curve + temp = 1 / temp + print("%i: %f" % (step, temp)) \ No newline at end of file