From c50cce2a623e1ebf137263762db0162ed3c9c6c8 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 23 Jul 2020 17:03:54 -0600
Subject: [PATCH] Add an abstract, configurable weight scheduling class and
 apply it to the feature weight

---
 codes/data/weight_scheduler.py | 64 ++++++++++++++++++++++++++++++++++
 codes/models/SRGAN_model.py    | 37 +++++++++++---------
 2 files changed, 85 insertions(+), 16 deletions(-)
 create mode 100644 codes/data/weight_scheduler.py

diff --git a/codes/data/weight_scheduler.py b/codes/data/weight_scheduler.py
new file mode 100644
index 00000000..b0b8cfc7
--- /dev/null
+++ b/codes/data/weight_scheduler.py
@@ -0,0 +1,64 @@
+import math
+from matplotlib import pyplot as plt
+
+# Base class for weight schedulers. Holds the weight at a fixed initial value.
+class WeightScheduler:
+    def __init__(self, initial_weight):
+        self.initial_weight = initial_weight
+
+    def get_weight_for_step(self, step):
+        return self.initial_weight
+
+
+class LinearDecayWeightScheduler(WeightScheduler):
+    def __init__(self, initial_weight, steps_to_decay, lower_bound, initial_step=0):
+        super(LinearDecayWeightScheduler, self).__init__(initial_weight)
+        self.steps_to_decay = steps_to_decay
+        self.lower_bound = lower_bound
+        self.initial_step = initial_step
+        self.decrease_per_step = (initial_weight - lower_bound) / self.steps_to_decay
+
+    def get_weight_for_step(self, step):
+        step = step - self.initial_step
+        if step < 0:
+            return self.initial_weight
+        return max(self.lower_bound, self.initial_weight - step * self.decrease_per_step)
+
+
+class SinusoidalWeightScheduler(WeightScheduler):
+    def __init__(self, upper_weight, lower_weight, period_steps, initial_step=0):
+        super(SinusoidalWeightScheduler, self).__init__(upper_weight)
+        self.center = (upper_weight + lower_weight) / 2
+        self.amplitude = (upper_weight - lower_weight) / 2
+        self.period = period_steps
+        self.initial_step = initial_step
+
+    def get_weight_for_step(self, step):
+        step = step - self.initial_step
+        if step < 0:
+            return self.initial_weight
+        # Use cosine because it starts at y=1 for x=0.
+        return math.cos(step * math.pi * 2 / self.period) * self.amplitude + self.center
+
+
+def get_scheduler_for_opt(opt):
+    if opt['type'] == 'fixed':
+        return WeightScheduler(opt['weight'])
+    elif opt['type'] == 'linear_decay':
+        return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
+    elif opt['type'] == 'sinusoidal':
+        return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
+
+
+# Do some testing.
+if __name__ == "__main__":
+    # sched = SinusoidalWeightScheduler(1, .1, 50, 10)
+    sched = LinearDecayWeightScheduler(1, 150, .1, 20)
+
+    x = []
+    y = []
+    for s in range(200):
+        x.append(s)
+        y.append(sched.get_weight_for_step(s))
+    plt.plot(x, y)
+    plt.show()
\ No newline at end of file
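Usage sketch for the factory above, beyond the plotting test in __main__: the option values below are hypothetical, but the dict keys are exactly the ones get_scheduler_for_opt() reads for the 'linear_decay' type.

    from data.weight_scheduler import get_scheduler_for_opt

    # Hypothetical values; mirrors the LinearDecayWeightScheduler test above.
    sched = get_scheduler_for_opt({
        'type': 'linear_decay',
        'initial_weight': 1.0,
        'steps': 150,
        'lower_bound': .1,
        'start_step': 20,
    })

    # Schedulers are stateless, so the weight at any step can be queried directly.
    for step in (0, 20, 95, 170, 1000):
        print(step, sched.get_weight_for_step(step))
    # Holds 1.0 through step 20, decays linearly to 0.1 by step 170, then clamps there.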
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index d88fa46c..4bcdaa94 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler
 from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
+from data.weight_scheduler import get_scheduler_for_opt
 import torch.nn.functional as F
 import glob
 import random
@@ -61,7 +62,22 @@ class SRGANModel(BaseModel):
             self.cri_pix = None
 
         # G feature loss
-        if train_opt['feature_weight'] > 0:
+        if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
+            # Backwards compatibility: wraps the old fixed option in a scheduler. Remove this at some point; new configs should use a 'feature_scheduler' definition instead.
+            l_fea_type = train_opt['feature_criterion']
+            if l_fea_type == 'l1':
+                self.cri_fea = nn.L1Loss().to(self.device)
+            elif l_fea_type == 'l2':
+                self.cri_fea = nn.MSELoss().to(self.device)
+            else:
+                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+            sched_params = {
+                'type': 'fixed',
+                'weight': train_opt['feature_weight']
+            }
+            self.l_fea_sched = get_scheduler_for_opt(sched_params)
+        elif train_opt['feature_scheduler']:
+            self.l_fea_sched = get_scheduler_for_opt(train_opt['feature_scheduler'])
             l_fea_type = train_opt['feature_criterion']
             if l_fea_type == 'l1':
                 self.cri_fea = nn.L1Loss().to(self.device)
@@ -69,13 +85,6 @@ class SRGANModel(BaseModel):
                 self.cri_fea = nn.MSELoss().to(self.device)
             else:
                 raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
-            self.l_fea_w = train_opt['feature_weight']
-            self.l_fea_w_start = train_opt['feature_weight']
-            self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
-            self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
-            self.l_fea_w_minimum = train_opt['feature_weight_minimum']
-            if self.l_fea_w_decay_start:
-                self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
         else:
             logger.info('Remove feature loss.')
             self.cri_fea = None
@@ -283,19 +292,15 @@ class SRGANModel(BaseModel):
                 if self.cri_fea:  # feature loss
                     real_fea = self.netF(pix).detach()
                     fake_fea = self.netF(fea_GenOut)
-                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
-                    l_g_fea_log = l_g_fea / self.l_fea_w
+                    fea_w = self.l_fea_sched.get_weight_for_step(step)
+                    l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
+                    l_g_fea_log = l_g_fea / fea_w
                     l_g_total += l_g_fea
 
                     if _profile:
                         print("Fea forward %f" % (time() - _t,))
                         _t = time()
 
-                    # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
-                    # in the resultant image.
-                    if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
-                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w_start - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))
-
                 # Note to future self: BCELoss(0, 1) and BCELoss(0, 0) both equal .6931.
                 # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake are
                 # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
@@ -507,7 +512,7 @@ class SRGANModel(BaseModel):
             if self.cri_pix:
                 self.add_log_entry('l_g_pix', l_g_pix_log.item())
             if self.cri_fea:
-                self.add_log_entry('feature_weight', self.l_fea_w)
+                self.add_log_entry('feature_weight', fea_w)
                 self.add_log_entry('l_g_fea', l_g_fea_log.item())
             if self.l_gan_w > 0:
                 self.add_log_entry('l_g_gan', l_g_gan_log.item())
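Migration note: the deleted feature_weight_decay_* options map one-to-one onto the new 'linear_decay' scheduler's constructor arguments, so an existing config can be ported by hand. A sketch with hypothetical values, written as the train_opt dict the model reads (the surrounding options-file syntax is unchanged):

    # Old-style options (still honored via the backwards-compatibility branch):
    #   feature_weight             = 1.0
    #   feature_weight_decay_start = 20000
    #   feature_weight_decay_steps = 150000
    #   feature_weight_minimum     = .1
    # Equivalent scheduler definition; note 'feature_weight' must be absent or
    # zero for the elif branch in SRGANModel to pick this up.
    train_opt['feature_scheduler'] = {
        'type': 'linear_decay',
        'initial_weight': 1.0,   # was feature_weight
        'start_step': 20000,     # was feature_weight_decay_start
        'steps': 150000,         # was feature_weight_decay_steps
        'lower_bound': .1,       # was feature_weight_minimum
    }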