From 8f0984cacf3cc022a08514611e41dce0bd0030d6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 30 Dec 2020 20:18:58 -0700 Subject: [PATCH] Add sr_fid evaluator --- codes/train.py | 4 +-- codes/trainer/eval/__init__.py | 15 ---------- codes/trainer/eval/evaluator.py | 47 ++++++++++++++++++++++++++++++ codes/trainer/eval/sr_fid.py | 51 +++++++++++++++++++++++++++++++++ codes/trainer/inject.py | 4 +-- codes/utils/util.py | 5 +++- 6 files changed, 106 insertions(+), 20 deletions(-) create mode 100644 codes/trainer/eval/sr_fid.py diff --git a/codes/train.py b/codes/train.py index 5ad1271f..93045042 100644 --- a/codes/train.py +++ b/codes/train.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch from data.data_sampler import DistIterSampler -from trainer.eval import create_evaluator +from trainer.eval.evaluator import create_evaluator from utils import util, options as option from data import create_dataloader, create_dataset @@ -293,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/eval/__init__.py b/codes/trainer/eval/__init__.py index e8338096..e69de29b 100644 --- a/codes/trainer/eval/__init__.py +++ b/codes/trainer/eval/__init__.py @@ -1,15 +0,0 @@ -from trainer.eval.flow_gaussian_nll import FlowGaussianNll -from trainer.eval.sr_style import SrStyleTransferEvaluator -from trainer.eval.style import StyleTransferEvaluator - - -def create_evaluator(model, opt_eval, env): - type = opt_eval['type'] - if type == 'style_transfer': - return StyleTransferEvaluator(model, opt_eval, env) - elif type == 'sr_stylegan': - return SrStyleTransferEvaluator(model, opt_eval, env) - elif type == 'flownet_gaussian': - return FlowGaussianNll(model, opt_eval, env) - else: - raise NotImplementedError() \ No newline at end of file diff --git a/codes/trainer/eval/evaluator.py b/codes/trainer/eval/evaluator.py index 5f0a364f..5fd76d21 100644 --- a/codes/trainer/eval/evaluator.py +++ b/codes/trainer/eval/evaluator.py @@ -1,4 +1,11 @@ # Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response. +import importlib +import inspect +import pkgutil +import re +import sys + + class Evaluator: def __init__(self, model, opt_eval, env): self.model = model.module if hasattr(model, 'module') else model @@ -7,3 +14,43 @@ class Evaluator: def perform_eval(self): return {} + + +def format_evaluator_name(name): + # Formats by converting from CamelCase to snake_case and removing trailing "_injector" + name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() + return name.replace("_evaluator", "") + + +# Works by loading all python modules in the eval/ directory and sniffing out subclasses of Evaluator. +def find_registered_evaluators(base_path="trainer/eval"): + module_iter = pkgutil.walk_packages([base_path]) + results = {} + for mod in module_iter: + if mod.ispkg: + EXCLUSION_LIST = [] + if mod.name not in EXCLUSION_LIST: + results.update(find_registered_evaluators(f'{base_path}/{mod.name}')) + else: + mod_name = f'{base_path}/{mod.name}'.replace('/', '.') + importlib.import_module(mod_name) + classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass) + for name, obj in classes: + if 'Evaluator' in [mro.__name__ for mro in inspect.getmro(obj)]: + results[format_evaluator_name(name)] = obj + return results + + +class CreateEvaluatorError(Exception): + def __init__(self, name, available): + super().__init__(f'Could not find the specified evaluator name: {name}. Available evaluators:' + f'{available}') + + +def create_evaluator(model, opt_eval, env): + evaluators = find_registered_evaluators() + type = opt_eval['type'] + if type not in evaluators.keys(): + raise CreateEvaluatorError(type, list(evaluators.keys())) + return evaluators[opt_eval['type']](model, opt_eval, env) diff --git a/codes/trainer/eval/sr_fid.py b/codes/trainer/eval/sr_fid.py new file mode 100644 index 00000000..d20393ad --- /dev/null +++ b/codes/trainer/eval/sr_fid.py @@ -0,0 +1,51 @@ +import os +import torch +import os.path as osp +import torchvision +from torch.nn.functional import interpolate +from tqdm import tqdm + +import trainer.eval.evaluator as evaluator + +from pytorch_fid import fid_score +from data import create_dataset +from torch.utils.data import DataLoader + + +# Computes the SR FID score for a network. +class SrFidEvaluator(evaluator.Evaluator): + def __init__(self, model, opt_eval, env): + super().__init__(model, opt_eval, env) + self.batch_sz = opt_eval['batch_size'] + assert self.batch_sz is not None + self.dataset = create_dataset(opt_eval['dataset']) + self.scale = opt_eval['scale'] + self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset. + assert isinstance(self.fid_real_samples, str) + self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1) + self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 + + def perform_eval(self): + fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"])) + os.makedirs(fid_fake_path, exist_ok=True) + counter = 0 + for batch in tqdm(self.dataloader): + lq = batch['lq'].to(self.env['device']) + gen = self.model(lq) + if not isinstance(gen, list) and not isinstance(gen, tuple): + gen = [gen] + gen = gen[self.gen_output_index] + + # Remove low-frequency differences + gen_lf = interpolate(interpolate(gen, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale, + mode="nearest") + gen_hf = gen - gen_lf + hq_lf = interpolate(lq, scale_factor=self.scale, mode="nearest") + hq_gen_hf_applied = hq_lf + gen_hf + + for b in range(self.batch_sz): + torchvision.utils.save_image(hq_gen_hf_applied[b], osp.join(fid_fake_path, "%i_.png" % (counter))) + counter += 1 + + return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True, + 2048)} diff --git a/codes/trainer/inject.py b/codes/trainer/inject.py index 56ed4026..77da4f1e 100644 --- a/codes/trainer/inject.py +++ b/codes/trainer/inject.py @@ -29,7 +29,7 @@ def format_injector_name(name): return name.replace("_injector", "") -# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__() +# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector. # field will be properly populated. def find_registered_injectors(base_path="trainer/injectors"): module_iter = pkgutil.walk_packages([base_path]) @@ -61,4 +61,4 @@ def create_injector(opt_inject, env): type = opt_inject['type'] if type not in injectors.keys(): raise CreateInjectorError(type, list(injectors.keys())) - return injectors[opt_inject['type']](opt_inject, env) \ No newline at end of file + return injectors[opt_inject['type']](opt_inject, env) diff --git a/codes/utils/util.py b/codes/utils/util.py index 146a880c..cf42ccb9 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -55,7 +55,10 @@ def checkpoint(fn, *args): return fn(*args) def sequential_checkpoint(fn, partitions, *args): - enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True + if loaded_options is None: + enabled = False + else: + enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True if enabled: return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args) else: