forked from mrq/DL-Art-School
Another fix
This commit is contained in:
parent
54bff35171
commit
6a75bd0777
|
@ -260,7 +260,7 @@ class Trainer:
|
|||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
if eval.uses_all_ddp() or self.rank <= 0:
|
||||
if eval.uses_all_ddp or self.rank <= 0:
|
||||
eval_dict.update(eval.perform_eval())
|
||||
if self.rank <= 0:
|
||||
print("Evaluator results: ", eval_dict)
|
||||
|
|
|
@ -13,7 +13,7 @@ from utils.util import opt_get
|
|||
|
||||
class CategorizationLossEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
assert self.batch_sz is not None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
|
|
|
@ -7,11 +7,11 @@ import sys
|
|||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, model, opt_eval, env):
|
||||
def __init__(self, model, opt_eval, env, uses_all_ddp=True):
|
||||
self.model = model.module if hasattr(model, 'module') else model
|
||||
self.opt = opt_eval
|
||||
self.env = env
|
||||
self.uses_all_ddp = opt_get(opt_eval, ['uses_all_ddp'], True)
|
||||
self.uses_all_ddp = uses_all_ddp
|
||||
|
||||
def perform_eval(self):
|
||||
return {}
|
||||
|
|
|
@ -10,7 +10,7 @@ from utils.util import opt_get
|
|||
# Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results.
|
||||
class StyleTransferEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.im_sz = opt_eval['image_size']
|
||||
|
|
|
@ -11,7 +11,7 @@ from models.srflow.flow import GaussianDiag
|
|||
|
||||
class FlowGaussianNll(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.dataset = ImageFolderDataset(opt_eval['dataset'])
|
||||
self.dataloader = DataLoader(self.dataset, self.batch_sz)
|
||||
|
|
|
@ -20,7 +20,7 @@ from utils.util import opt_get
|
|||
# dissimilar points remain constant or get further apart.
|
||||
class SinglePointPairContrastiveEval(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.eval_qty = opt_eval['quantity']
|
||||
assert self.eval_qty % self.batch_sz == 0
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
|||
# generator might make from the source image.
|
||||
class SrFidEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
assert self.batch_sz is not None
|
||||
self.dataset = create_dataset(opt_eval['dataset'])
|
||||
|
|
|
@ -16,7 +16,7 @@ from data.stylegan2_dataset import Stylegan2Dataset
|
|||
|
||||
class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.im_sz = opt_eval['image_size']
|
||||
|
|
Loading…
Reference in New Issue
Block a user