forked from mrq/DL-Art-School
24792bdb4f
Removed a lot of legacy stuff I have no intent on using again. Plan is to shape this repo into something more extensible (get it? hah!)
66 lines
2.4 KiB
Python
66 lines
2.4 KiB
Python
import math
|
|
from matplotlib import pyplot as plt
|
|
|
|
class WeightScheduler:
    """Base weight scheduler: reports one constant weight for every step.

    Subclasses override ``get_weight_for_step`` to make the weight depend on
    the step; this base implementation never changes it.
    """

    def __init__(self, initial_weight):
        # Public attribute: subclasses fall back to it before their schedule starts.
        self.initial_weight = initial_weight

    def get_weight_for_step(self, step):
        """Return the weight to use at ``step`` (``step`` is ignored here)."""
        return self.initial_weight
|
|
|
|
|
|
class LinearDecayWeightScheduler(WeightScheduler):
    """Weight that falls linearly from ``initial_weight`` towards ``lower_bound``.

    Before ``initial_step`` the weight is held at ``initial_weight``. From
    ``initial_step`` on it drops by a constant amount per step, chosen so that
    ``lower_bound`` is reached ``steps_to_decay`` steps later, and it is never
    reported below ``lower_bound``.

    Args:
        initial_weight: weight reported until the decay starts.
        steps_to_decay: number of steps the decay takes. ``0`` means the
            weight drops to ``lower_bound`` immediately at ``initial_step``.
        lower_bound: floor applied to the returned weight.
        initial_step: step at which the decay begins (default 0).
    """

    def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
        super(LinearDecayWeightScheduler, self).__init__(initial_weight)
        self.steps_to_decay = steps_to_decay
        self.lower_bound = lower_bound
        self.initial_step = initial_step
        if steps_to_decay == 0:
            # A zero-length decay used to raise ZeroDivisionError here. It is
            # handled as an instantaneous drop in get_weight_for_step(); the
            # slope is recorded as infinite and is not used in that case.
            self.decrease_per_step = float('inf')
        else:
            self.decrease_per_step = (initial_weight - lower_bound) / steps_to_decay

    def get_weight_for_step(self, step):
        """Return the weight for ``step``, clamped from below at ``lower_bound``."""
        # Work in steps elapsed since the decay was scheduled to begin.
        step = step - self.initial_step
        if step < 0:
            return self.initial_weight
        if self.steps_to_decay == 0:
            # Instantaneous decay: the floor applies as soon as the schedule starts.
            return self.lower_bound
        return max(self.lower_bound, self.initial_weight - step * self.decrease_per_step)
|
|
|
|
|
|
class SinusoidalWeightScheduler(WeightScheduler):
    """Weight that oscillates between ``upper_weight`` and ``lower_weight``.

    Before ``initial_step`` the weight is held at ``upper_weight``. From
    ``initial_step`` on it follows a cosine wave with a period of
    ``period_steps`` steps, centred midway between the two weights.
    """

    def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
        # The base class keeps upper_weight as initial_weight, which is what
        # gets reported before the schedule starts.
        super().__init__(upper_weight)
        self.initial_step = initial_step
        self.period = period_steps
        # Midpoint and half-range of the oscillation.
        self.center = (upper_weight + lower_weight) / 2
        self.amplitude = (upper_weight - lower_weight) / 2

    def get_weight_for_step(self, step):
        """Return the weight for ``step`` on the cosine schedule."""
        elapsed = step - self.initial_step
        if elapsed < 0:
            return self.initial_weight
        # Cosine rather than sine: cos(0) == 1, so the wave begins at its peak.
        angle = elapsed * math.pi * 2 / self.period
        return math.cos(angle) * self.amplitude + self.center
|
|
|
|
|
|
def get_scheduler_for_opt(opt):
    """Build the weight scheduler described by the options mapping ``opt``.

    ``opt['type']`` selects the scheduler and determines which other keys
    are read:

    * ``'fixed'``: ``weight``
    * ``'linear_decay'``: ``initial_weight``, ``steps``, ``lower_bound``
      and optionally ``start_step``
    * ``'sinusoidal'``: ``upper_weight``, ``lower_weight``, ``period``
      and optionally ``start_step``

    ``start_step`` falls back to 0 when it is absent or ``None``, matching
    the ``initial_step=0`` default of the scheduler constructors.

    Raises:
        NotImplementedError: if ``opt['type']`` is not one of the types above.
    """
    kind = opt['type']
    if kind == 'fixed':
        return WeightScheduler(opt['weight'])
    if kind == 'linear_decay':
        return LinearDecayWeightScheduler(
            opt['initial_weight'], opt['steps'], opt['lower_bound'],
            opt.get('start_step') or 0)
    if kind == 'sinusoidal':
        return SinusoidalWeightScheduler(
            opt['upper_weight'], opt['lower_weight'], opt['period'],
            opt.get('start_step') or 0)
    # Name the offending value so a misconfigured type is easy to spot.
    raise NotImplementedError('Unknown weight scheduler type: {!r}'.format(kind))
|
|
|
|
|
|
# Manual check: plot a scheduler's weight over the first 200 steps.
if __name__ == "__main__":
    # Alternative scheduler to plot instead:
    #sched = SinusoidalWeightScheduler(1, .1, 50, 10)
    sched = LinearDecayWeightScheduler(1, 150, .1, 20)

    x = list(range(200))
    y = [sched.get_weight_for_step(s) for s in x]
    plt.plot(x, y)
    plt.show()