Add an abstract, configurable weight scheduling class and apply it to the feature weight

James Betker 2020-07-23 17:03:54 -06:00
parent 9ccf771629
commit c50cce2a62
2 changed files with 85 additions and 16 deletions

data/weight_scheduler.py (new file)
@@ -0,0 +1,64 @@
import math

from matplotlib import pyplot as plt


# Base class for weight schedulers. Holds the weight at a fixed initial value.
class WeightScheduler:
    def __init__(self, initial_weight):
        self.initial_weight = initial_weight

    def get_weight_for_step(self, step):
        return self.initial_weight


# Decays the weight linearly from initial_weight to lower_bound over steps_to_decay
# steps, starting at initial_step. The weight is held at lower_bound thereafter.
class LinearDecayWeightScheduler(WeightScheduler):
    def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
        super(LinearDecayWeightScheduler, self).__init__(initial_weight)
        self.steps_to_decay = steps_to_decay
        self.lower_bound = lower_bound
        self.initial_step = initial_step
        self.decrease_per_step = (initial_weight - lower_bound) / self.steps_to_decay

    def get_weight_for_step(self, step):
        step = step - self.initial_step
        if step < 0:
            return self.initial_weight
        return max(self.lower_bound, self.initial_weight - step * self.decrease_per_step)


# Oscillates the weight between upper_weight and lower_weight with the given period,
# starting at upper_weight on initial_step.
class SinusoidalWeightScheduler(WeightScheduler):
    def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
        super(SinusoidalWeightScheduler, self).__init__(upper_weight)
        self.center = (upper_weight + lower_weight) / 2
        self.amplitude = (upper_weight - lower_weight) / 2
        self.period = period_steps
        self.initial_step = initial_step

    def get_weight_for_step(self, step):
        step = step - self.initial_step
        if step < 0:
            return self.initial_weight
        # Use cosine because it starts at y=1 for x=0.
        return math.cos(step * math.pi * 2 / self.period) * self.amplitude + self.center


# Builds a scheduler from an options dict. Raises on unrecognized types rather than
# silently returning None.
def get_scheduler_for_opt(opt):
    if opt['type'] == 'fixed':
        return WeightScheduler(opt['weight'])
    elif opt['type'] == 'linear_decay':
        return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
    elif opt['type'] == 'sinusoidal':
        return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
    else:
        raise NotImplementedError('Scheduler type [{:s}] not recognized.'.format(opt['type']))


# Do some testing.
if __name__ == "__main__":
    # sched = SinusoidalWeightScheduler(1, .1, 50, 10)
    sched = LinearDecayWeightScheduler(1, 150, .1, 20)
    x = []
    y = []
    for s in range(200):
        x.append(s)
        y.append(sched.get_weight_for_step(s))
    plt.plot(x, y)
    plt.show()
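The options dicts consumed by get_scheduler_for_opt map directly onto the constructor arguments above. As a quick reference, here is a minimal sketch of all three supported types; the key names come straight from get_scheduler_for_opt, and the printed values follow from the decay formula:

    from data.weight_scheduler import get_scheduler_for_opt

    # 'fixed': always returns the same weight.
    fixed = get_scheduler_for_opt({'type': 'fixed', 'weight': 1.0})

    # 'linear_decay': holds 1.0 until step 20, then decays to .1 over 150 steps.
    decay = get_scheduler_for_opt({'type': 'linear_decay', 'initial_weight': 1.0,
                                   'steps': 150, 'lower_bound': .1, 'start_step': 20})

    # 'sinusoidal': oscillates between 1.0 and .1 with a 50-step period from step 10.
    sine = get_scheduler_for_opt({'type': 'sinusoidal', 'upper_weight': 1.0,
                                  'lower_weight': .1, 'period': 50, 'start_step': 10})

    print(decay.get_weight_for_step(0))    # 1.0: before start_step, the initial weight is held
    print(decay.get_weight_for_step(95))   # ~0.55: halfway through the 150-step decay
    print(decay.get_weight_for_step(500))  # 0.1: clamped at lower_bound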

models/SRGAN_model.py
@@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler
 from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
+from data.weight_scheduler import get_scheduler_for_opt
 import torch.nn.functional as F
 import glob
 import random
@@ -61,7 +62,22 @@ class SRGANModel(BaseModel):
                 self.cri_pix = None

             # G feature loss
-            if train_opt['feature_weight'] > 0:
+            if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
+                # Backwards-compatibility path: wraps the fixed feature_weight in a 'fixed' scheduler. Prefer 'feature_scheduler'; remove this at some point.
+                l_fea_type = train_opt['feature_criterion']
+                if l_fea_type == 'l1':
+                    self.cri_fea = nn.L1Loss().to(self.device)
+                elif l_fea_type == 'l2':
+                    self.cri_fea = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                sched_params = {
+                    'type': 'fixed',
+                    'weight': train_opt['feature_weight']
+                }
+                self.l_fea_sched = get_scheduler_for_opt(sched_params)
+            elif train_opt['feature_scheduler']:
+                self.l_fea_sched = get_scheduler_for_opt(train_opt['feature_scheduler'])
                 l_fea_type = train_opt['feature_criterion']
                 if l_fea_type == 'l1':
                     self.cri_fea = nn.L1Loss().to(self.device)
@@ -69,13 +85,6 @@ class SRGANModel(BaseModel):
                     self.cri_fea = nn.MSELoss().to(self.device)
                 else:
                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
-                self.l_fea_w = train_opt['feature_weight']
-                self.l_fea_w_start = train_opt['feature_weight']
-                self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
-                self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
-                self.l_fea_w_minimum = train_opt['feature_weight_minimum']
-                if self.l_fea_w_decay_start:
-                    self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
             else:
                 logger.info('Remove feature loss.')
                 self.cri_fea = None
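If I read the removed logic correctly, the seven deleted l_fea_w_* fields are subsumed by a single LinearDecayWeightScheduler. A hypothetical sketch (not part of the commit) of the old decay options expressed as a scheduler definition:

    # Hypothetical: the legacy decay fields mapped onto a 'feature_scheduler' options dict.
    legacy_equivalent = {
        'type': 'linear_decay',
        'initial_weight': train_opt['feature_weight'],
        'steps': train_opt['feature_weight_decay_steps'],
        'lower_bound': train_opt['feature_weight_minimum'],
        'start_step': train_opt['feature_weight_decay_start'],
    }
    self.l_fea_sched = get_scheduler_for_opt(legacy_equivalent)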
@@ -283,19 +292,15 @@ class SRGANModel(BaseModel):
             if self.cri_fea:  # feature loss
                 real_fea = self.netF(pix).detach()
                 fake_fea = self.netF(fea_GenOut)
-                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
-                l_g_fea_log = l_g_fea / self.l_fea_w
+                fea_w = self.l_fea_sched.get_weight_for_step(step)
+                l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_fea_log = l_g_fea / fea_w
                 l_g_total += l_g_fea

                 if _profile:
                     print("Fea forward %f" % (time() - _t,))
                     _t = time()

-                # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
-                # in the resultant image.
-                if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
-                    self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w_start - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))
-
             # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
             # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
             # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
@@ -507,7 +512,7 @@ class SRGANModel(BaseModel):
         if self.cri_pix:
             self.add_log_entry('l_g_pix', l_g_pix_log.item())
         if self.cri_fea:
-            self.add_log_entry('feature_weight', self.l_fea_w)
+            self.add_log_entry('feature_weight', fea_w)
             self.add_log_entry('l_g_fea', l_g_fea_log.item())
         if self.l_gan_w > 0:
             self.add_log_entry('l_g_gan', l_g_gan_log.item())
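Putting it together, a training options fragment for the new path might look like the following. This is a hypothetical sketch: the exact option-file format depends on how train_opt is parsed elsewhere in the codebase, but the keys inside feature_scheduler come from get_scheduler_for_opt:

    # Hypothetical train_opt fragment using the new scheduler path.
    train_opt = {
        'feature_criterion': 'l1',   # feature loss type, as before
        'feature_weight': None,      # skip the legacy fixed-weight branch
        'feature_scheduler': {       # handed to get_scheduler_for_opt
            'type': 'linear_decay',
            'initial_weight': 1.0,
            'steps': 150000,
            'lower_bound': .1,
            'start_step': 20000,
        },
    }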