Add an abstract, configurable weight scheduling class and apply it to the feature weight
parent 9ccf771629
commit c50cce2a62
codes/data/weight_scheduler.py (new file, 64 lines)

@@ -0,0 +1,64 @@
+import math
+from matplotlib import pyplot as plt
+
+
+# Base class for weight schedulers. Holds weight at a fixed initial value.
+class WeightScheduler:
+    def __init__(self, initial_weight):
+        self.initial_weight = initial_weight
+
+    def get_weight_for_step(self, step):
+        return self.initial_weight
+
+
+class LinearDecayWeightScheduler(WeightScheduler):
+    def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
+        super(LinearDecayWeightScheduler, self).__init__(initial_weight)
+        self.steps_to_decay = steps_to_decay
+        self.lower_bound = lower_bound
+        self.initial_step = initial_step
+        self.decrease_per_step = (initial_weight - lower_bound) / self.steps_to_decay
+
+    def get_weight_for_step(self, step):
+        step = step - self.initial_step
+        if step < 0:
+            return self.initial_weight
+        return max(self.lower_bound, self.initial_weight - step * self.decrease_per_step)
+
+
+class SinusoidalWeightScheduler(WeightScheduler):
+    def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
+        super(SinusoidalWeightScheduler, self).__init__(upper_weight)
+        self.center = (upper_weight + lower_weight) / 2
+        self.amplitude = (upper_weight - lower_weight) / 2
+        self.period = period_steps
+        self.initial_step = initial_step
+
+    def get_weight_for_step(self, step):
+        step = step - self.initial_step
+        if step < 0:
+            return self.initial_weight
+        # Use cosine because it starts at y=1 for x=0.
+        return math.cos(step * math.pi * 2 / self.period) * self.amplitude + self.center
+
+
+def get_scheduler_for_opt(opt):
+    if opt['type'] == 'fixed':
+        return WeightScheduler(opt['weight'])
+    elif opt['type'] == 'linear_decay':
+        return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
+    elif opt['type'] == 'sinusoidal':
+        return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
+
+
+# Do some testing.
+if __name__ == "__main__":
+    #sched = SinusoidalWeightScheduler(1, .1, 50, 10)
+    sched = LinearDecayWeightScheduler(1, 150, .1, 20)
+
+    x = []
+    y = []
+    for s in range(200):
+        x.append(s)
+        y.append(sched.get_weight_for_step(s))
+    plt.plot(x, y)
+    plt.show()
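To see the new API end to end, here is a minimal sketch that drives the factory with an options dict (the dict literal is illustrative; in training it would come from the parsed options file, and the import path assumes the script runs from the codes directory, matching the model code's import below):

from data.weight_scheduler import get_scheduler_for_opt

# Decay the weight from 1.0 down to 0.1 over 150 steps, starting at step 20.
# The keys mirror the 'linear_decay' branch of get_scheduler_for_opt.
opt = {'type': 'linear_decay', 'initial_weight': 1.0, 'steps': 150,
       'lower_bound': 0.1, 'start_step': 20}
sched = get_scheduler_for_opt(opt)

print(sched.get_weight_for_step(10))   # 1.0  -- before start_step the weight is held
print(sched.get_weight_for_step(95))   # 0.55 -- 75 steps into the decay
print(sched.get_weight_for_step(500))  # 0.1  -- clamped at lower_bound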
@@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler
 from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
+from data.weight_scheduler import get_scheduler_for_opt
 import torch.nn.functional as F
 import glob
 import random
@@ -61,7 +62,22 @@ class SRGANModel(BaseModel):
                 self.cri_pix = None
 
             # G feature loss
-            if train_opt['feature_weight'] > 0:
+            if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
+                # For backwards compatibility, use a scheduler definition instead. Remove this at some point.
+                l_fea_type = train_opt['feature_criterion']
+                if l_fea_type == 'l1':
+                    self.cri_fea = nn.L1Loss().to(self.device)
+                elif l_fea_type == 'l2':
+                    self.cri_fea = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                sched_params = {
+                    'type': 'fixed',
+                    'weight': train_opt['feature_weight']
+                }
+                self.l_fea_sched = get_scheduler_for_opt(sched_params)
+            elif train_opt['feature_scheduler']:
+                self.l_fea_sched = get_scheduler_for_opt(train_opt['feature_scheduler'])
                 l_fea_type = train_opt['feature_criterion']
                 if l_fea_type == 'l1':
                     self.cri_fea = nn.L1Loss().to(self.device)
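For reference, the two option styles this branch now accepts might look like the following (a hedged sketch: plain dict literals stand in for whatever the options parser actually supplies as train_opt):

# Legacy style: a bare feature_weight gets wrapped in a fixed scheduler.
train_opt_legacy = {'feature_criterion': 'l1', 'feature_weight': 1.0}

# New style: feature_weight is unset, so the elif branch passes the
# scheduler definition straight through to get_scheduler_for_opt.
train_opt_scheduled = {
    'feature_criterion': 'l1',
    'feature_weight': None,
    'feature_scheduler': {'type': 'linear_decay', 'initial_weight': 1.0,
                          'steps': 150, 'lower_bound': 0.1, 'start_step': 20},
}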
@@ -69,13 +85,6 @@ class SRGANModel(BaseModel):
                     self.cri_fea = nn.MSELoss().to(self.device)
                 else:
                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
-                self.l_fea_w = train_opt['feature_weight']
-                self.l_fea_w_start = train_opt['feature_weight']
-                self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
-                self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
-                self.l_fea_w_minimum = train_opt['feature_weight_minimum']
-                if self.l_fea_w_decay_start:
-                    self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
             else:
                 logger.info('Remove feature loss.')
                 self.cri_fea = None
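The seven deleted lines were a hand-rolled linear decay; the scheduler expresses the same thing declaratively. As a sketch of the migration (with one boundary caveat: the old code only decayed once step > decay_start, while LinearDecayWeightScheduler starts decaying at step >= initial_step):

# Hypothetical scheduler equivalent of the removed l_fea_w_* fields.
sched = LinearDecayWeightScheduler(
    train_opt['feature_weight'],              # was l_fea_w_start
    train_opt['feature_weight_decay_steps'],  # was l_fea_w_decay_steps
    train_opt['feature_weight_minimum'],      # was l_fea_w_minimum
    train_opt['feature_weight_decay_start'])  # was l_fea_w_decay_start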
@@ -283,19 +292,15 @@ class SRGANModel(BaseModel):
             if self.cri_fea:  # feature loss
                 real_fea = self.netF(pix).detach()
                 fake_fea = self.netF(fea_GenOut)
-                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
-                l_g_fea_log = l_g_fea / self.l_fea_w
+                fea_w = self.l_fea_sched.get_weight_for_step(step)
+                l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_fea_log = l_g_fea / fea_w
                 l_g_total += l_g_fea
 
                 if _profile:
                     print("Fea forward %f" % (time() - _t,))
                     _t = time()
 
-                # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
-                # in the resultant image.
-                if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
-                    self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w_start - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))
-
             # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
             # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
             # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
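Dividing the weight back out in l_g_fea_log keeps the logged feature loss comparable across the whole run no matter where the schedule currently sits; the weight itself is logged separately in the next hunk. A self-contained sketch of that invariant:

from data.weight_scheduler import LinearDecayWeightScheduler

sched = LinearDecayWeightScheduler(1.0, 150, 0.1, 20)
raw_fea_loss = 0.25  # stand-in for a cri_fea(fake_fea, real_fea) value
for step in (0, 100, 10000):
    fea_w = sched.get_weight_for_step(step)
    l_g_fea = fea_w * raw_fea_loss  # weighted term added to l_g_total
    # The logged value is weight-independent; the nonzero lower_bound also
    # guarantees fea_w never hits zero, so the division is safe.
    assert abs(l_g_fea / fea_w - raw_fea_loss) < 1e-9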
@@ -507,7 +512,7 @@ class SRGANModel(BaseModel):
         if self.cri_pix:
             self.add_log_entry('l_g_pix', l_g_pix_log.item())
         if self.cri_fea:
-            self.add_log_entry('feature_weight', self.l_fea_w)
+            self.add_log_entry('feature_weight', fea_w)
             self.add_log_entry('l_g_fea', l_g_fea_log.item())
         if self.l_gan_w > 0:
             self.add_log_entry('l_g_gan', l_g_gan_log.item())