forked from mrq/DL-Art-School
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
|
import math
|
||
|
from matplotlib import pyplot as plt
|
||
|
|
||
|
# Base class for weight schedulers. Holds weight at a fixed initial value.
|
||
|
class WeightScheduler:
|
||
|
def __init__(self, initial_weight):
|
||
|
self.initial_weight = initial_weight
|
||
|
|
||
|
def get_weight_for_step(self, step):
|
||
|
return self.initial_weight
|
||
|
|
||
|
|
||
|
class LinearDecayWeightScheduler(WeightScheduler):
|
||
|
def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
|
||
|
super(LinearDecayWeightScheduler, self).__init__(initial_weight)
|
||
|
self.steps_to_decay = steps_to_decay
|
||
|
self.lower_bound = lower_bound
|
||
|
self.initial_step = initial_step
|
||
|
self.decrease_per_step = (initial_weight - lower_bound) / self.steps_to_decay
|
||
|
|
||
|
def get_weight_for_step(self, step):
|
||
|
step = step - self.initial_step
|
||
|
if step < 0:
|
||
|
return self.initial_weight
|
||
|
return max(self.lower_bound, self.initial_weight - step * self.decrease_per_step)
|
||
|
|
||
|
|
||
|
class SinusoidalWeightScheduler(WeightScheduler):
|
||
|
def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
|
||
|
super(SinusoidalWeightScheduler, self).__init__(upper_weight)
|
||
|
self.center = (upper_weight + lower_weight) / 2
|
||
|
self.amplitude = (upper_weight - lower_weight) / 2
|
||
|
self.period = period_steps
|
||
|
self.initial_step = initial_step
|
||
|
|
||
|
def get_weight_for_step(self, step):
|
||
|
step = step - self.initial_step
|
||
|
if step < 0:
|
||
|
return self.initial_weight
|
||
|
# Use cosine because it starts at y=1 for x=0.
|
||
|
return math.cos(step * math.pi * 2 / self.period) * self.amplitude + self.center
|
||
|
|
||
|
|
||
|
def get_scheduler_for_opt(opt):
|
||
|
if opt['type'] == 'fixed':
|
||
|
return WeightScheduler(opt['weight'])
|
||
|
elif opt['type'] == 'linear_decay':
|
||
|
return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
|
||
|
elif opt['type'] == 'sinusoidal':
|
||
|
return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
|
||
|
|
||
|
|
||
|
# Do some testing.
|
||
|
if __name__ == "__main__":
|
||
|
#sched = SinusoidalWeightScheduler(1, .1, 50, 10)
|
||
|
sched = LinearDecayWeightScheduler(1, 150, .1, 20)
|
||
|
|
||
|
x = []
|
||
|
y = []
|
||
|
for s in range(200):
|
||
|
x.append(s)
|
||
|
y.append(sched.get_weight_for_step(s))
|
||
|
plt.plot(x, y)
|
||
|
plt.show()
|