From 4c0f770f2a6f9078dbc4b214255e7a466fddf1d2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 12 Jul 2020 11:02:50 -0600 Subject: [PATCH] Fix inverted temperature curve bug --- codes/models/archs/SwitchedResidualGenerator_arch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ef63ab60..c214819d 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -1,6 +1,6 @@ import torch from torch import nn -from switched_conv import BareConvSwitch, compute_attention_specificity +from switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm import torch.nn.functional as F import functools from collections import OrderedDict @@ -208,7 +208,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. # without this, the attention specificity "spikes" incredibly fast in the last few iterations. h_steps_total = self.heightened_final_step - self.final_temperature_step - h_steps_current = min(step - self.final_temperature_step, h_steps_total) + h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1) # The "gap" will represent the steps that need to be traveled as a linear function. h_gap = 1 / self.heightened_temp_min temp = h_gap * h_steps_current / h_steps_total