Add sr_fid evaluator

James Betker 2020-12-30 20:18:58 -07:00
parent b1fb82476b
commit 8f0984cacf
6 changed files with 106 additions and 20 deletions

View File

@@ -7,7 +7,7 @@ from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
from trainer.eval import create_evaluator
from trainer.eval.evaluator import create_evaluator
from utils import util, options as option
from data import create_dataloader, create_dataset
@@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -1,15 +0,0 @@
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
from trainer.eval.sr_style import SrStyleTransferEvaluator
from trainer.eval.style import StyleTransferEvaluator
def create_evaluator(model, opt_eval, env):
type = opt_eval['type']
if type == 'style_transfer':
return StyleTransferEvaluator(model, opt_eval, env)
elif type == 'sr_stylegan':
return SrStyleTransferEvaluator(model, opt_eval, env)
elif type == 'flownet_gaussian':
return FlowGaussianNll(model, opt_eval, env)
else:
raise NotImplementedError()

View File

@@ -1,4 +1,11 @@
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
import importlib
import inspect
import pkgutil
import re
import sys
class Evaluator:
def __init__(self, model, opt_eval, env):
self.model = model.module if hasattr(model, 'module') else model
@@ -7,3 +14,43 @@ class Evaluator:
def perform_eval(self):
return {}
def format_evaluator_name(name):
# Formats by converting from CamelCase to snake_case and removing the trailing "_evaluator"
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
return name.replace("_evaluator", "")
# Works by loading all python modules in the eval/ directory and sniffing out subclasses of Evaluator.
def find_registered_evaluators(base_path="trainer/eval"):
module_iter = pkgutil.walk_packages([base_path])
results = {}
for mod in module_iter:
if mod.ispkg:
EXCLUSION_LIST = []
if mod.name not in EXCLUSION_LIST:
results.update(find_registered_evaluators(f'{base_path}/{mod.name}'))
else:
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
importlib.import_module(mod_name)
classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass)
for name, obj in classes:
if 'Evaluator' in [mro.__name__ for mro in inspect.getmro(obj)]:
results[format_evaluator_name(name)] = obj
return results
class CreateEvaluatorError(Exception):
def __init__(self, name, available):
super().__init__(f'Could not find the specified evaluator name: {name}. Available evaluators:'
f'{available}')
def create_evaluator(model, opt_eval, env):
evaluators = find_registered_evaluators()
type = opt_eval['type']
if type not in evaluators.keys():
raise CreateEvaluatorError(type, list(evaluators.keys()))
return evaluators[opt_eval['type']](model, opt_eval, env)
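
With this change, evaluators are discovered automatically instead of being wired into a hand-maintained factory: any subclass of Evaluator found under trainer/eval/ is picked up, and its config "type" is the class name converted from CamelCase to snake_case with the trailing "_evaluator" dropped. A minimal sketch of the mapping and lookup, assuming the repository root is on the Python path (the opt_eval and env values are hypothetical):

from trainer.eval.evaluator import format_evaluator_name, create_evaluator

# Class names found by find_registered_evaluators() become config "type" keys:
format_evaluator_name("SrFidEvaluator")   # -> "sr_fid"
format_evaluator_name("FlowGaussianNll")  # -> "flow_gaussian_nll"

# Hypothetical options; the keys mirror what SrFidEvaluator reads in __init__.
# opt_eval = {'type': 'sr_fid', 'batch_size': 8, 'scale': 4, 'dataset': {...}}
# env = {'device': 'cuda', 'base_path': './experiments/example', 'step': 1000}
# ev = create_evaluator(model, opt_eval, env)  # unknown types raise CreateEvaluatorError
# metrics = ev.perform_eval()                  # e.g. {"fid": ...}

An unrecognized type now fails with the list of available evaluators rather than a bare NotImplementedError.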

View File

@@ -0,0 +1,51 @@
import os
import torch
import os.path as osp
import torchvision
from torch.nn.functional import interpolate
from tqdm import tqdm
import trainer.eval.evaluator as evaluator
from pytorch_fid import fid_score
from data import create_dataset
from torch.utils.data import DataLoader
# Computes the SR FID score for a network.
class SrFidEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
self.batch_sz = opt_eval['batch_size']
assert self.batch_sz is not None
self.dataset = create_dataset(opt_eval['dataset'])
self.scale = opt_eval['scale']
self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset.
assert isinstance(self.fid_real_samples, str)
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
counter = 0
for batch in tqdm(self.dataloader):
lq = batch['lq'].to(self.env['device'])
gen = self.model(lq)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]
gen = gen[self.gen_output_index]
# Remove low-frequency differences
gen_lf = interpolate(interpolate(gen, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale,
mode="nearest")
gen_hf = gen - gen_lf
hq_lf = interpolate(lq, scale_factor=self.scale, mode="nearest")
hq_gen_hf_applied = hq_lf + gen_hf
for b in range(self.batch_sz):
torchvision.utils.save_image(hq_gen_hf_applied[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
counter += 1
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
2048)}
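
The new evaluator scores only the detail the network adds: it strips each generated image's low frequencies (area-downsample by 1/scale, then nearest-upsample back), pastes the remaining high-frequency residue onto a nearest-neighbor upsample of the LQ input, writes the composites to disk, and runs pytorch_fid against the real images in the dataset's paths entry. A hedged sketch of an options block that would drive it, with hypothetical paths and sizes:

# Hypothetical eval options for this evaluator; only the keys read in __init__ above matter.
opt_eval = {
    'type': 'sr_fid',             # resolved via the auto-registration shown earlier
    'batch_size': 8,              # used for the DataLoader and passed to pytorch_fid
    'scale': 4,                   # SR factor used for the low-frequency swap
    'gen_index': 0,               # which model output to score if the model returns a tuple/list
    'dataset': {                  # standard dataset options; 'paths' doubles as the real-image dir for FID
        'paths': '/data/hq_images',
        # ...remaining dataset options as required by create_dataset()
    },
}

The composites are saved under <base_path>/../fid/<step>/ before scoring, and the FID is computed with 2048-dimensional Inception features.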

View File

@@ -29,7 +29,7 @@ def format_injector_name(name):
return name.replace("_injector", "")
# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__()
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
# field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"):
module_iter = pkgutil.walk_packages([base_path])
@@ -61,4 +61,4 @@ def create_injector(opt_inject, env):
type = opt_inject['type']
if type not in injectors.keys():
raise CreateInjectorError(type, list(injectors.keys()))
return injectors[opt_inject['type']](opt_inject, env)
return injectors[opt_inject['type']](opt_inject, env)

View File

@@ -55,7 +55,10 @@ def checkpoint(fn, *args):
return fn(*args)
def sequential_checkpoint(fn, partitions, *args):
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
if loaded_options is None:
enabled = False
else:
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
if enabled:
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
else: