Add separate Evaluator module and FID evaluator

This commit is contained in:
James Betker 2020-11-13 11:03:54 -07:00
parent 080ad61be4
commit a07e1a7292
6 changed files with 88 additions and 1 deletions

View File

@ -562,7 +562,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def forward(self, x):
b, f, h, w = x.shape
full_random_latents = False
full_random_latents = True
if full_random_latents:
style = self.noise(b*2, self.gen.latent_dim, x.device)
w = self.vectorizer(style)

View File

@ -0,0 +1,9 @@
from models.eval.style import StyleTransferEvaluator
def create_evaluator(model, opt_eval, env):
type = opt_eval['type']
if type == 'style_transfer':
return StyleTransferEvaluator(model, opt_eval, env)
else:
raise NotImplementedError()

View File

@ -0,0 +1,9 @@
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
class Evaluator:
def __init__(self, model, opt_eval, env):
self.model = model
self.opt = opt_eval
self.env = env
def perform_eval(self):
return {}

View File

@ -0,0 +1,35 @@
import os
import torch
import os.path as osp
import torchvision
import models.eval.evaluator as evaluator
from pytorch_fid import fid_score
# Evaluate that generates uniform noise to feed into a generator, then calculates a FID score on the results.
class StyleTransferEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
self.batches_per_eval = opt_eval['batches_per_eval']
self.batch_sz = opt_eval['batch_size']
self.im_sz = opt_eval['image_size']
self.fid_real_samples = opt_eval['real_fid_path']
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
counter = 0
for i in range(self.batches_per_eval):
batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
gen = self.model(batch)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]
gen = gen[self.gen_output_index]
for b in range(self.batch_sz):
torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
counter += 1
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
2048)}

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
from models.eval import create_evaluator
from utils import util, options as option
from data import create_dataloader, create_dataset
@ -129,6 +130,13 @@ class Trainer:
#### create model
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
### Evaluators
self.evaluators = []
if 'evaluators' in opt['eval'].keys():
for ev_key, ev_opt in opt['eval']['evaluators'].items():
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
ev_opt, self.model.env))
#### resume training
if resume_state:
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
@ -241,11 +249,20 @@ class Trainer:
# log
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
eval_dict = {}
for eval in self.evaluators:
eval_dict.update(eval.perform_eval())
print("Evaluator results: ", eval_dict)
for ek, ev in eval_dict.items():
self.tb_logger.add_scalar(ek, ev, self.current_step)
def do_training(self):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
from models.eval import create_evaluator
from utils import util, options as option
from data import create_dataloader, create_dataset
@ -129,6 +130,13 @@ class Trainer:
#### create model
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
### Evaluators
self.evaluators = []
if 'evaluators' in opt['eval'].keys():
for ev_key, ev_opt in opt['eval']['evaluators'].items():
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
ev_opt, self.model.env))
#### resume training
if resume_state:
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
@ -241,11 +249,20 @@ class Trainer:
# log
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
eval_dict = {}
for eval in self.evaluators:
eval_dict.update(eval.perform_eval())
print("Evaluator results: ", eval_dict)
for ek, ev in eval_dict.items():
self.tb_logger.add_scalar(ek, ev, self.current_step)
def do_training(self):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1):