diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index ef63ab60..c214819d 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from switched_conv import BareConvSwitch, compute_attention_specificity
+from switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
@@ -208,7 +208,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                 # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
                 # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
                 h_steps_total = self.heightened_final_step - self.final_temperature_step
-                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
+                h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1)
                 # The "gap" will represent the steps that need to be traveled as a linear function.
                 h_gap = 1 / self.heightened_temp_min
                 temp = h_gap * h_steps_current / h_steps_total
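
The behavioral change in the second hunk is the `max(..., 1)` clamp on `h_steps_current`. Below is a minimal sketch of that heightened-temperature schedule, pulled out of the class for illustration: `heightened_temperature` is a hypothetical standalone helper (not part of the repository), its parameter names mirror the attributes referenced in the diff, and how the returned `temp` is used downstream is not shown in this hunk and is not assumed here.

# Hypothetical standalone sketch of the schedule from the hunk above.
def heightened_temperature(step, final_temperature_step, heightened_final_step,
                           heightened_temp_min):
    # Total number of steps the "heightened" linear ramp spans.
    h_steps_total = heightened_final_step - final_temperature_step
    # Clamp to at least 1 (the change in this diff) so that step ==
    # final_temperature_step does not yield h_steps_current == 0 and hence
    # temp == 0 at the very first step of the ramp.
    h_steps_current = max(min(step - final_temperature_step, h_steps_total), 1)
    # The "gap" is traversed as a linear function of the current step,
    # per the comments in the original code.
    h_gap = 1 / heightened_temp_min
    return h_gap * h_steps_current / h_steps_total

if __name__ == '__main__':
    # Illustrative values only: without the clamp, the first step of the ramp
    # would produce 0; with it, the schedule starts just above 0 and grows
    # linearly until heightened_final_step.
    for step in (20000, 20001, 30000, 40000):
        print(step, heightened_temperature(step, final_temperature_step=20000,
                                           heightened_final_step=40000,
                                           heightened_temp_min=1e-2))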