From 886d59d5df13426d2a240de92bd9b861399900b8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 1 Sep 2020 07:58:11 -0600 Subject: [PATCH] Misc fixes & adjustments --- codes/models/ExtensibleTrainer.py | 4 ++-- codes/models/steps/losses.py | 6 ++++-- codes/train.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index aa6dd288..29ff0859 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -54,8 +54,8 @@ class ExtensibleTrainer(BaseModel): step = ConfigurableStep(step, self.env) self.steps.append(step) - # The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped - # yet. + # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though + # they aren't wrapped yet. self.env['generators'] = self.netsG self.env['discriminators'] = self.netsD diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index b99d5f73..d4133135 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -118,6 +118,7 @@ class DiscriminatorGanLoss(ConfigurableLoss): d_real = net(state[self.opt['real']]) d_fake = net(state[self.opt['fake']].detach()) self.metrics.append(("d_fake", torch.mean(d_fake))) + self.metrics.append(("d_real", torch.mean(d_real))) if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']: l_real = self.criterion(d_real, True) @@ -129,10 +130,11 @@ class DiscriminatorGanLoss(ConfigurableLoss): l_total += l_mreal + l_mfake self.metrics.append(("l_mismatch", l_mfake + l_mreal)) self.metrics.append(("l_fake", l_fake)) + self.metrics.append(("l_real", l_real)) return l_total elif self.opt['gan_type'] == 'ragan': - return (self.cri_gan(d_real - torch.mean(d_fake), True) + - self.cri_gan(d_fake - torch.mean(d_real), False)) + return (self.criterion(d_real - torch.mean(d_fake), True) + + self.criterion(d_fake - torch.mean(d_real), False)) else: raise NotImplementedError diff --git a/codes/train.py b/codes/train.py index 661e9b92..f4014552 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)