forked from mrq/DL-Art-School
Add sr_fid evaluator
This commit is contained in:
parent
b1fb82476b
commit
8f0984cacf
|
@ -7,7 +7,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from data.data_sampler import DistIterSampler
|
from data.data_sampler import DistIterSampler
|
||||||
from trainer.eval import create_evaluator
|
from trainer.eval.evaluator import create_evaluator
|
||||||
|
|
||||||
from utils import util, options as option
|
from utils import util, options as option
|
||||||
from data import create_dataloader, create_dataset
|
from data import create_dataloader, create_dataset
|
||||||
|
@ -293,7 +293,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
|
||||||
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
|
||||||
from trainer.eval.style import StyleTransferEvaluator
|
|
||||||
|
|
||||||
|
|
||||||
def create_evaluator(model, opt_eval, env):
|
|
||||||
type = opt_eval['type']
|
|
||||||
if type == 'style_transfer':
|
|
||||||
return StyleTransferEvaluator(model, opt_eval, env)
|
|
||||||
elif type == 'sr_stylegan':
|
|
||||||
return SrStyleTransferEvaluator(model, opt_eval, env)
|
|
||||||
elif type == 'flownet_gaussian':
|
|
||||||
return FlowGaussianNll(model, opt_eval, env)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
|
@ -1,4 +1,11 @@
|
||||||
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import pkgutil
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
self.model = model.module if hasattr(model, 'module') else model
|
self.model = model.module if hasattr(model, 'module') else model
|
||||||
|
@ -7,3 +14,43 @@ class Evaluator:
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def format_evaluator_name(name):
|
||||||
|
# Formats by converting from CamelCase to snake_case and removing trailing "_injector"
|
||||||
|
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||||
|
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||||||
|
return name.replace("_evaluator", "")
|
||||||
|
|
||||||
|
|
||||||
|
# Works by loading all python modules in the eval/ directory and sniffing out subclasses of Evaluator.
|
||||||
|
def find_registered_evaluators(base_path="trainer/eval"):
|
||||||
|
module_iter = pkgutil.walk_packages([base_path])
|
||||||
|
results = {}
|
||||||
|
for mod in module_iter:
|
||||||
|
if mod.ispkg:
|
||||||
|
EXCLUSION_LIST = []
|
||||||
|
if mod.name not in EXCLUSION_LIST:
|
||||||
|
results.update(find_registered_evaluators(f'{base_path}/{mod.name}'))
|
||||||
|
else:
|
||||||
|
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
|
||||||
|
importlib.import_module(mod_name)
|
||||||
|
classes = inspect.getmembers(sys.modules[mod_name], inspect.isclass)
|
||||||
|
for name, obj in classes:
|
||||||
|
if 'Evaluator' in [mro.__name__ for mro in inspect.getmro(obj)]:
|
||||||
|
results[format_evaluator_name(name)] = obj
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEvaluatorError(Exception):
|
||||||
|
def __init__(self, name, available):
|
||||||
|
super().__init__(f'Could not find the specified evaluator name: {name}. Available evaluators:'
|
||||||
|
f'{available}')
|
||||||
|
|
||||||
|
|
||||||
|
def create_evaluator(model, opt_eval, env):
|
||||||
|
evaluators = find_registered_evaluators()
|
||||||
|
type = opt_eval['type']
|
||||||
|
if type not in evaluators.keys():
|
||||||
|
raise CreateEvaluatorError(type, list(evaluators.keys()))
|
||||||
|
return evaluators[opt_eval['type']](model, opt_eval, env)
|
||||||
|
|
51
codes/trainer/eval/sr_fid.py
Normal file
51
codes/trainer/eval/sr_fid.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import os.path as osp
|
||||||
|
import torchvision
|
||||||
|
from torch.nn.functional import interpolate
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import trainer.eval.evaluator as evaluator
|
||||||
|
|
||||||
|
from pytorch_fid import fid_score
|
||||||
|
from data import create_dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
# Computes the SR FID score for a network.
|
||||||
|
class SrFidEvaluator(evaluator.Evaluator):
|
||||||
|
def __init__(self, model, opt_eval, env):
|
||||||
|
super().__init__(model, opt_eval, env)
|
||||||
|
self.batch_sz = opt_eval['batch_size']
|
||||||
|
assert self.batch_sz is not None
|
||||||
|
self.dataset = create_dataset(opt_eval['dataset'])
|
||||||
|
self.scale = opt_eval['scale']
|
||||||
|
self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset.
|
||||||
|
assert isinstance(self.fid_real_samples, str)
|
||||||
|
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
|
||||||
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
|
|
||||||
|
def perform_eval(self):
|
||||||
|
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
||||||
|
os.makedirs(fid_fake_path, exist_ok=True)
|
||||||
|
counter = 0
|
||||||
|
for batch in tqdm(self.dataloader):
|
||||||
|
lq = batch['lq'].to(self.env['device'])
|
||||||
|
gen = self.model(lq)
|
||||||
|
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||||
|
gen = [gen]
|
||||||
|
gen = gen[self.gen_output_index]
|
||||||
|
|
||||||
|
# Remove low-frequency differences
|
||||||
|
gen_lf = interpolate(interpolate(gen, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale,
|
||||||
|
mode="nearest")
|
||||||
|
gen_hf = gen - gen_lf
|
||||||
|
hq_lf = interpolate(lq, scale_factor=self.scale, mode="nearest")
|
||||||
|
hq_gen_hf_applied = hq_lf + gen_hf
|
||||||
|
|
||||||
|
for b in range(self.batch_sz):
|
||||||
|
torchvision.utils.save_image(hq_gen_hf_applied[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
||||||
|
2048)}
|
|
@ -29,7 +29,7 @@ def format_injector_name(name):
|
||||||
return name.replace("_injector", "")
|
return name.replace("_injector", "")
|
||||||
|
|
||||||
|
|
||||||
# Works by loading all python modules in the injectors/ directory. The end result of this will be that the Injector.__subclasses__()
|
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
|
||||||
# field will be properly populated.
|
# field will be properly populated.
|
||||||
def find_registered_injectors(base_path="trainer/injectors"):
|
def find_registered_injectors(base_path="trainer/injectors"):
|
||||||
module_iter = pkgutil.walk_packages([base_path])
|
module_iter = pkgutil.walk_packages([base_path])
|
||||||
|
@ -61,4 +61,4 @@ def create_injector(opt_inject, env):
|
||||||
type = opt_inject['type']
|
type = opt_inject['type']
|
||||||
if type not in injectors.keys():
|
if type not in injectors.keys():
|
||||||
raise CreateInjectorError(type, list(injectors.keys()))
|
raise CreateInjectorError(type, list(injectors.keys()))
|
||||||
return injectors[opt_inject['type']](opt_inject, env)
|
return injectors[opt_inject['type']](opt_inject, env)
|
||||||
|
|
|
@ -55,7 +55,10 @@ def checkpoint(fn, *args):
|
||||||
return fn(*args)
|
return fn(*args)
|
||||||
|
|
||||||
def sequential_checkpoint(fn, partitions, *args):
|
def sequential_checkpoint(fn, partitions, *args):
|
||||||
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
if loaded_options is None:
|
||||||
|
enabled = False
|
||||||
|
else:
|
||||||
|
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||||
if enabled:
|
if enabled:
|
||||||
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
|
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user