DL-Art-School/codes/data/weight_scheduler.py

64 lines
2.3 KiB
Python

import math
from matplotlib import pyplot as plt
class WeightScheduler:
    """Schedule that reports the same weight at every step.

    Also acts as the base class for the other schedulers in this module,
    which store their starting weight in ``initial_weight`` as well.
    """

    def __init__(self, initial_weight):
        # The constant weight reported by get_weight_for_step().
        self.initial_weight = initial_weight

    def get_weight_for_step(self, step):
        """Return ``initial_weight`` regardless of ``step``.

        ``step`` is accepted only so all schedulers share one call signature.
        """
        del step  # unused: a fixed schedule does not depend on time
        return self.initial_weight
class LinearDecayWeightScheduler(WeightScheduler):
    """Move the weight linearly from ``initial_weight`` to ``lower_bound``.

    The weight is held at ``initial_weight`` before ``initial_step``. It then
    changes linearly, reaches ``lower_bound`` after ``steps_to_decay`` steps,
    and stays there. ``lower_bound`` may also be larger than
    ``initial_weight``; the weight then ramps up and is clamped at that
    target instead of growing without bound.

    Args:
        initial_weight: Weight before (and at) the start of the transition.
        steps_to_decay: Number of steps the transition takes. ``0`` switches
            to ``lower_bound`` immediately at ``initial_step``.
        lower_bound: Weight reached at the end of the transition.
        initial_step: Global step at which the transition begins.

    Raises:
        ValueError: If ``steps_to_decay`` is negative.
    """

    def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
        super(LinearDecayWeightScheduler, self).__init__(initial_weight)
        if steps_to_decay < 0:
            raise ValueError(
                "steps_to_decay must be non-negative, got {!r}".format(steps_to_decay))
        self.steps_to_decay = steps_to_decay
        self.lower_bound = lower_bound
        self.initial_step = initial_step
        if steps_to_decay > 0:
            self.decrease_per_step = (initial_weight - lower_bound) / steps_to_decay
        else:
            # Zero-length transition: the whole change happens at once. This
            # avoids a ZeroDivisionError; get_weight_for_step never reads the
            # value in this case.
            self.decrease_per_step = initial_weight - lower_bound

    def get_weight_for_step(self, step):
        """Return the scheduled weight for global training ``step``."""
        step = step - self.initial_step
        if step < 0:
            return self.initial_weight
        if step >= self.steps_to_decay:
            # End of the ramp: return the target exactly rather than
            # extrapolating past it, whichever direction the ramp runs.
            return self.lower_bound
        weight = self.initial_weight - step * self.decrease_per_step
        # Clamp between the two endpoints so float rounding can never
        # overshoot either one (works for both decay and ramp-up).
        low, high = sorted((self.initial_weight, self.lower_bound))
        return min(high, max(low, weight))
class SinusoidalWeightScheduler(WeightScheduler):
    """Oscillate the weight between ``upper_weight`` and ``lower_weight``.

    Before ``initial_step`` the weight is held at ``upper_weight``. From then
    on it follows a cosine wave with a period of ``period_steps`` steps. Each
    cycle starts at ``upper_weight`` and reaches ``lower_weight`` half-way
    through.
    """

    def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
        super().__init__(upper_weight)
        # Midline and half-swing of the wave.
        self.center = (upper_weight + lower_weight) / 2
        self.amplitude = (upper_weight - lower_weight) / 2
        self.period = period_steps
        self.initial_step = initial_step

    def get_weight_for_step(self, step):
        """Return the weight for global ``step`` (``upper_weight`` before the start)."""
        elapsed = step - self.initial_step
        if elapsed < 0:
            return self.initial_weight
        # Cosine rather than sine: it equals 1 at phase 0, so the wave starts at
        # its peak and lines up with the held pre-start value.
        phase = elapsed * (2 * math.pi) / self.period
        return self.amplitude * math.cos(phase) + self.center
def get_scheduler_for_opt(opt):
    """Build a weight scheduler from an options mapping.

    ``opt['type']`` selects the scheduler and the keys it needs:

    * ``'fixed'``: ``weight``.
    * ``'linear_decay'``: ``initial_weight``, ``steps``, ``lower_bound``.
    * ``'sinusoidal'``: ``upper_weight``, ``lower_weight``, ``period``.

    For the non-fixed types, ``start_step`` is optional. It defaults to 0,
    matching the schedulers' own ``initial_step`` default.

    Raises:
        KeyError: If ``type`` or a required key for that type is missing.
        ValueError: If ``opt['type']`` is not one of the names above.
    """
    kind = opt['type']
    if kind == 'fixed':
        return WeightScheduler(opt['weight'])
    if kind == 'linear_decay':
        return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'],
                                          opt.get('start_step', 0))
    if kind == 'sinusoidal':
        return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'],
                                         opt.get('start_step', 0))
    # Fail loudly here. Returning None would only surface later, as an
    # AttributeError far away in the caller.
    raise ValueError("Unknown weight scheduler type: {!r}".format(kind))
# Manual visual check: plot one schedule over the first 200 steps.
if __name__ == "__main__":
    # To inspect the sinusoidal schedule instead, swap in e.g.
    # SinusoidalWeightScheduler(1, .1, 50, 10).
    sched = LinearDecayWeightScheduler(1, 150, .1, 20)
    steps = list(range(200))
    weights = [sched.get_weight_for_step(s) for s in steps]
    plt.plot(steps, weights)
    plt.show()