forked from mrq/DL-Art-School
Add sr_fid evaluator
This commit is contained in:
parent
b1fb82476b
commit
8f0984cacf
|
@ -7,7 +7,7 @@ from tqdm import tqdm
|
|||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
from trainer.eval import create_evaluator
|
||||
from trainer.eval.evaluator import create_evaluator
|
||||
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
||||
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
||||
from trainer.eval.style import StyleTransferEvaluator
|
||||
|
||||
|
||||
def create_evaluator(model, opt_eval, env):
|
||||
type = opt_eval['type']
|
||||
if type == 'style_transfer':
|
||||
return StyleTransferEvaluator(model, opt_eval, env)
|
||||
elif type == 'sr_stylegan':
|
||||
return SrStyleTransferEvaluator(model, opt_eval, env)
|
||||
elif type == 'flownet_gaussian':
|
||||
return FlowGaussianNll(model, opt_eval, env)
|
||||
else:
|
||||
raise NotImplementedError()
|
|
@ -1,4 +1,11 @@
|
|||
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, model, opt_eval, env):
|
||||
self.model = model.module if hasattr(model, 'module') else model
|
||||
|
@ -7,3 +14,43 @@ class Evaluator:
|
|||
|
||||
def perform_eval(self):
|
||||
return {}
|
||||
|
||||
|
||||
def format_evaluator_name(name):
|
||||
# Formats by converting from CamelCase to snake_case and removing trailing "_injector"
|
||||
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||||
return name.replace("_evaluator", "")
|
||||
|
||||
|
||||
# Works by loading all python modules in the eval/ directory and sniffing out subclasses of Evaluator.
|
||||
def find_registered_evaluators(base_path="trainer/eval"):
|
||||
module_iter = pkgutil.walk_packages([base_path])
|
||||
results = {}
|
||||
for mod in module_iter:
|
||||
if mod.ispkg:
|
||||
EXCLUSION_LIST = []
|
||||
if mod.name not in EXCLUSION_LIST:
|
||||
results.update(find_registered_evaluators(f'{base_path}/{mod.name}'))
|
||||
else:
|
||||
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
|
||||
importlib.import_module(mod_name)
|
||||
classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass)
|
||||
for name, obj in classes:
|
||||
if 'Evaluator' in [mro.__name__ for mro in inspect.getmro(obj)]:
|
||||
results[format_evaluator_name(name)] = obj
|
||||
return results
|
||||
|
||||
|
||||
class CreateEvaluatorError(Exception):
|
||||
def __init__(self, name, available):
|
||||
super().__init__(f'Could not find the specified evaluator name: {name}. Available evaluators:'
|
||||
f'{available}')
|
||||
|
||||
|
||||
def create_evaluator(model, opt_eval, env):
|
||||
evaluators = find_registered_evaluators()
|
||||
type = opt_eval['type']
|
||||
if type not in evaluators.keys():
|
||||
raise CreateEvaluatorError(type, list(evaluators.keys()))
|
||||
return evaluators[opt_eval['type']](model, opt_eval, env)
|
||||
|
|
51
codes/trainer/eval/sr_fid.py
Normal file
51
codes/trainer/eval/sr_fid.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import os
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm import tqdm
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
|
||||
from pytorch_fid import fid_score
|
||||
from data import create_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
# Computes the SR FID score for a network.
|
||||
class SrFidEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
assert self.batch_sz is not None
|
||||
self.dataset = create_dataset(opt_eval['dataset'])
|
||||
self.scale = opt_eval['scale']
|
||||
self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset.
|
||||
assert isinstance(self.fid_real_samples, str)
|
||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
|
||||
def perform_eval(self):
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
counter = 0
|
||||
for batch in tqdm(self.dataloader):
|
||||
lq = batch['lq'].to(self.env['device'])
|
||||
gen = self.model(lq)
|
||||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||
gen = [gen]
|
||||
gen = gen[self.gen_output_index]
|
||||
|
||||
# Remove low-frequency differences
|
||||
gen_lf = interpolate(interpolate(gen, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale,
|
||||
mode="nearest")
|
||||
gen_hf = gen - gen_lf
|
||||
hq_lf = interpolate(lq, scale_factor=self.scale, mode="nearest")
|
||||
hq_gen_hf_applied = hq_lf + gen_hf
|
||||
|
||||
for b in range(self.batch_sz):
|
||||
torchvision.utils.save_image(hq_gen_hf_applied[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
|
||||
counter += 1
|
||||
|
||||
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
||||
2048)}
|
|
@ -29,7 +29,7 @@ def format_injector_name(name):
|
|||
return name.replace("_injector", "")
|
||||
|
||||
|
||||
# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__()
|
||||
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
|
||||
# field will be properly populated.
|
||||
def find_registered_injectors(base_path="trainer/injectors"):
|
||||
module_iter = pkgutil.walk_packages([base_path])
|
||||
|
@ -61,4 +61,4 @@ def create_injector(opt_inject, env):
|
|||
type = opt_inject['type']
|
||||
if type not in injectors.keys():
|
||||
raise CreateInjectorError(type, list(injectors.keys()))
|
||||
return injectors[opt_inject['type']](opt_inject, env)
|
||||
return injectors[opt_inject['type']](opt_inject, env)
|
||||
|
|
|
@ -55,7 +55,10 @@ def checkpoint(fn, *args):
|
|||
return fn(*args)
|
||||
|
||||
def sequential_checkpoint(fn, partitions, *args):
|
||||
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
if loaded_options is None:
|
||||
enabled = False
|
||||
else:
|
||||
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
if enabled:
|
||||
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user