forked from mrq/DL-Art-School
Another fix
This commit is contained in:
parent
54bff35171
commit
6a75bd0777
|
@ -260,7 +260,7 @@ class Trainer:
|
||||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
||||||
eval_dict = {}
|
eval_dict = {}
|
||||||
for eval in self.evaluators:
|
for eval in self.evaluators:
|
||||||
if eval.uses_all_ddp() or self.rank <= 0:
|
if eval.uses_all_ddp or self.rank <= 0:
|
||||||
eval_dict.update(eval.perform_eval())
|
eval_dict.update(eval.perform_eval())
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
print("Evaluator results: ", eval_dict)
|
print("Evaluator results: ", eval_dict)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from utils.util import opt_get
|
||||||
|
|
||||||
class CategorizationLossEvaluator(evaluator.Evaluator):
|
class CategorizationLossEvaluator(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
assert self.batch_sz is not None
|
assert self.batch_sz is not None
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
|
|
@ -7,11 +7,11 @@ import sys
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env, uses_all_ddp=True):
|
||||||
self.model = model.module if hasattr(model, 'module') else model
|
self.model = model.module if hasattr(model, 'module') else model
|
||||||
self.opt = opt_eval
|
self.opt = opt_eval
|
||||||
self.env = env
|
self.env = env
|
||||||
self.uses_all_ddp = opt_get(opt_eval, ['uses_all_ddp'], True)
|
self.uses_all_ddp = uses_all_ddp
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -10,7 +10,7 @@ from utils.util import opt_get
|
||||||
# Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results.
|
# Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results.
|
||||||
class StyleTransferEvaluator(evaluator.Evaluator):
|
class StyleTransferEvaluator(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
self.im_sz = opt_eval['image_size']
|
self.im_sz = opt_eval['image_size']
|
||||||
|
|
|
@ -11,7 +11,7 @@ from models.srflow.flow import GaussianDiag
|
||||||
|
|
||||||
class FlowGaussianNll(evaluator.Evaluator):
|
class FlowGaussianNll(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
self.dataset = ImageFolderDataset(opt_eval['dataset'])
|
self.dataset = ImageFolderDataset(opt_eval['dataset'])
|
||||||
self.dataloader = DataLoader(self.dataset, self.batch_sz)
|
self.dataloader = DataLoader(self.dataset, self.batch_sz)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from utils.util import opt_get
|
||||||
# dissimilar points remain constant or get further apart.
|
# dissimilar points remain constant or get further apart.
|
||||||
class SinglePointPairContrastiveEval(evaluator.Evaluator):
|
class SinglePointPairContrastiveEval(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
self.eval_qty = opt_eval['quantity']
|
self.eval_qty = opt_eval['quantity']
|
||||||
assert self.eval_qty % self.batch_sz == 0
|
assert self.eval_qty % self.batch_sz == 0
|
||||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
||||||
# generator might make from the source image.
|
# generator might make from the source image.
|
||||||
class SrFidEvaluator(evaluator.Evaluator):
|
class SrFidEvaluator(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
assert self.batch_sz is not None
|
assert self.batch_sz is not None
|
||||||
self.dataset = create_dataset(opt_eval['dataset'])
|
self.dataset = create_dataset(opt_eval['dataset'])
|
||||||
|
|
|
@ -16,7 +16,7 @@ from data.stylegan2_dataset import Stylegan2Dataset
|
||||||
|
|
||||||
class SrStyleTransferEvaluator(evaluator.Evaluator):
|
class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||||
self.batch_sz = opt_eval['batch_size']
|
self.batch_sz = opt_eval['batch_size']
|
||||||
self.im_sz = opt_eval['image_size']
|
self.im_sz = opt_eval['image_size']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user