From 6a75bd0777ece2a182367d0d572675c51787f042 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 14 Jun 2021 09:51:44 -0600 Subject: [PATCH] Another fix --- codes/train.py | 2 +- codes/trainer/eval/categorization_loss_eval.py | 2 +- codes/trainer/eval/evaluator.py | 4 ++-- codes/trainer/eval/fid.py | 2 +- codes/trainer/eval/flow_gaussian_nll.py | 2 +- codes/trainer/eval/single_point_pair_contrastive_eval.py | 2 +- codes/trainer/eval/sr_fid.py | 2 +- codes/trainer/eval/sr_style.py | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/codes/train.py b/codes/train.py index 242cf515..f95e4baf 100644 --- a/codes/train.py +++ b/codes/train.py @@ -260,7 +260,7 @@ class Trainer: if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0: eval_dict = {} for eval in self.evaluators: - if eval.uses_all_ddp() or self.rank <= 0: + if eval.uses_all_ddp or self.rank <= 0: eval_dict.update(eval.perform_eval()) if self.rank <= 0: print("Evaluator results: ", eval_dict) diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index fcd009e3..2270a76d 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -13,7 +13,7 @@ from utils.util import opt_get class CategorizationLossEvaluator(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batch_sz = opt_eval['batch_size'] assert self.batch_sz is not None normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], diff --git a/codes/trainer/eval/evaluator.py b/codes/trainer/eval/evaluator.py index 519403c3..f1a87461 100644 --- a/codes/trainer/eval/evaluator.py +++ b/codes/trainer/eval/evaluator.py @@ -7,11 +7,11 @@ import sys class Evaluator: - def __init__(self, model, opt_eval, env): + def __init__(self, model, opt_eval, env, uses_all_ddp=True): self.model = model.module if hasattr(model, 'module') else model self.opt = opt_eval self.env = env - self.uses_all_ddp = opt_get(opt_eval, ['uses_all_ddp'], True) + self.uses_all_ddp = uses_all_ddp def perform_eval(self): return {} diff --git a/codes/trainer/eval/fid.py b/codes/trainer/eval/fid.py index f125fae7..3a5ccc07 100644 --- a/codes/trainer/eval/fid.py +++ b/codes/trainer/eval/fid.py @@ -10,7 +10,7 @@ from utils.util import opt_get # Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results. class StyleTransferEvaluator(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batches_per_eval = opt_eval['batches_per_eval'] self.batch_sz = opt_eval['batch_size'] self.im_sz = opt_eval['image_size'] diff --git a/codes/trainer/eval/flow_gaussian_nll.py b/codes/trainer/eval/flow_gaussian_nll.py index dacf51d1..d55ffd7c 100644 --- a/codes/trainer/eval/flow_gaussian_nll.py +++ b/codes/trainer/eval/flow_gaussian_nll.py @@ -11,7 +11,7 @@ from models.srflow.flow import GaussianDiag class FlowGaussianNll(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batch_sz = opt_eval['batch_size'] self.dataset = ImageFolderDataset(opt_eval['dataset']) self.dataloader = DataLoader(self.dataset, self.batch_sz) diff --git a/codes/trainer/eval/single_point_pair_contrastive_eval.py b/codes/trainer/eval/single_point_pair_contrastive_eval.py index 8cc0278f..d73ce6cd 100644 --- a/codes/trainer/eval/single_point_pair_contrastive_eval.py +++ b/codes/trainer/eval/single_point_pair_contrastive_eval.py @@ -20,7 +20,7 @@ from utils.util import opt_get # dissimilar points remain constant or get further apart. class SinglePointPairContrastiveEval(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batch_sz = opt_eval['batch_size'] self.eval_qty = opt_eval['quantity'] assert self.eval_qty % self.batch_sz == 0 diff --git a/codes/trainer/eval/sr_fid.py b/codes/trainer/eval/sr_fid.py index ad70f47e..ef27877d 100644 --- a/codes/trainer/eval/sr_fid.py +++ b/codes/trainer/eval/sr_fid.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader # generator might make from the source image. class SrFidEvaluator(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batch_sz = opt_eval['batch_size'] assert self.batch_sz is not None self.dataset = create_dataset(opt_eval['dataset']) diff --git a/codes/trainer/eval/sr_style.py b/codes/trainer/eval/sr_style.py index 3f8f6791..add84616 100644 --- a/codes/trainer/eval/sr_style.py +++ b/codes/trainer/eval/sr_style.py @@ -16,7 +16,7 @@ from data.stylegan2_dataset import Stylegan2Dataset class SrStyleTransferEvaluator(evaluator.Evaluator): def __init__(self, model, opt_eval, env): - super().__init__(model, opt_eval, env) + super().__init__(model, opt_eval, env, uses_all_ddp=False) self.batches_per_eval = opt_eval['batches_per_eval'] self.batch_sz = opt_eval['batch_size'] self.im_sz = opt_eval['image_size']