forked from mrq/DL-Art-School
Fix temperature equation
This commit is contained in:
parent
853468ef82
commit
7659bd6818
|
@ -294,9 +294,10 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
|
|||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
temp = max(1, int(
|
||||
self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
|
||||
if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
|
||||
temp = max(1,
|
||||
1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)
|
||||
if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \
|
||||
self.heightened_final_step != 1:
|
||||
# Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
|
||||
# without this, the attention specificity "spikes" incredibly fast in the last few iterations.
|
||||
h_steps_total = self.heightened_final_step - self.final_temperature_step
|
||||
|
|
|
@ -2,11 +2,21 @@ import onnx
|
|||
import numpy as np
|
||||
import time
|
||||
|
||||
model = onnx.load('../results/gen.onnx')
|
||||
init_temperature = 10
|
||||
final_temperature_step = 50
|
||||
heightened_final_step = 90
|
||||
heightened_temp_min = .1
|
||||
|
||||
outputs = {}
|
||||
for n in model.graph.node:
|
||||
for o in n.output:
|
||||
outputs[o] = n
|
||||
|
||||
res = 0
|
||||
for step in range(100):
|
||||
temp = max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step)
|
||||
if temp == 1 and step > final_temperature_step and heightened_final_step and heightened_final_step != 1:
|
||||
# Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
|
||||
# without this, the attention specificity "spikes" incredibly fast in the last few iterations.
|
||||
h_steps_total = heightened_final_step - final_temperature_step
|
||||
h_steps_current = min(step - final_temperature_step, h_steps_total)
|
||||
# The "gap" will represent the steps that need to be traveled as a linear function.
|
||||
h_gap = 1 / heightened_temp_min
|
||||
temp = h_gap * h_steps_current / h_steps_total
|
||||
# Invert temperature to represent reality on this side of the curve
|
||||
temp = 1 / temp
|
||||
print("%i: %f" % (step, temp))
|
Loading…
Reference in New Issue
Block a user