Another fix

This commit is contained in:
James Betker 2021-06-14 09:51:44 -06:00
parent 54bff35171
commit 6a75bd0777
8 changed files with 9 additions and 9 deletions

View File

@ -260,7 +260,7 @@ class Trainer:
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
eval_dict = {}
for eval in self.evaluators:
if eval.uses_all_ddp() or self.rank <= 0:
if eval.uses_all_ddp or self.rank <= 0:
eval_dict.update(eval.perform_eval())
if self.rank <= 0:
print("Evaluator results: ", eval_dict)

View File

@ -13,7 +13,7 @@ from utils.util import opt_get
class CategorizationLossEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batch_sz = opt_eval['batch_size']
assert self.batch_sz is not None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],

View File

@ -7,11 +7,11 @@ import sys
class Evaluator:
def __init__(self, model, opt_eval, env):
def __init__(self, model, opt_eval, env, uses_all_ddp=True):
self.model = model.module if hasattr(model, 'module') else model
self.opt = opt_eval
self.env = env
self.uses_all_ddp = opt_get(opt_eval, ['uses_all_ddp'], True)
self.uses_all_ddp = uses_all_ddp
def perform_eval(self):
return {}

View File

@ -10,7 +10,7 @@ from utils.util import opt_get
# Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results.
class StyleTransferEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batches_per_eval = opt_eval['batches_per_eval']
self.batch_sz = opt_eval['batch_size']
self.im_sz = opt_eval['image_size']

View File

@ -11,7 +11,7 @@ from models.srflow.flow import GaussianDiag
class FlowGaussianNll(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batch_sz = opt_eval['batch_size']
self.dataset = ImageFolderDataset(opt_eval['dataset'])
self.dataloader = DataLoader(self.dataset, self.batch_sz)

View File

@ -20,7 +20,7 @@ from utils.util import opt_get
# dissimilar points remain constant or get further apart.
class SinglePointPairContrastiveEval(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batch_sz = opt_eval['batch_size']
self.eval_qty = opt_eval['quantity']
assert self.eval_qty % self.batch_sz == 0

View File

@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
# generator might make from the source image.
class SrFidEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batch_sz = opt_eval['batch_size']
assert self.batch_sz is not None
self.dataset = create_dataset(opt_eval['dataset'])

View File

@ -16,7 +16,7 @@ from data.stylegan2_dataset import Stylegan2Dataset
class SrStyleTransferEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.batches_per_eval = opt_eval['batches_per_eval']
self.batch_sz = opt_eval['batch_size']
self.im_sz = opt_eval['image_size']