forked from mrq/DL-Art-School
Fix inverted temperature curve bug
This commit is contained in:
parent
14d23b9d20
commit
4c0f770f2a
|
@@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from switched_conv import BareConvSwitch, compute_attention_specificity
|
from switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@@ -208,7 +208,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
# Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
|
# Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
|
||||||
# without this, the attention specificity "spikes" incredibly fast in the last few iterations.
|
# without this, the attention specificity "spikes" incredibly fast in the last few iterations.
|
||||||
h_steps_total = self.heightened_final_step - self.final_temperature_step
|
h_steps_total = self.heightened_final_step - self.final_temperature_step
|
||||||
h_steps_current = min(step - self.final_temperature_step, h_steps_total)
|
h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1)
|
||||||
# The "gap" will represent the steps that need to be traveled as a linear function.
|
# The "gap" will represent the steps that need to be traveled as a linear function.
|
||||||
h_gap = 1 / self.heightened_temp_min
|
h_gap = 1 / self.heightened_temp_min
|
||||||
temp = h_gap * h_steps_current / h_steps_total
|
temp = h_gap * h_steps_current / h_steps_total
|
||||||
|
|
Loading…
Reference in New Issue
Block a user