forked from mrq/DL-Art-School
Add separate Evaluator module and FID evaluator
This commit is contained in:
parent
080ad61be4
commit
a07e1a7292
|
@ -562,7 +562,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
|||
def forward(self, x):
|
||||
b, f, h, w = x.shape
|
||||
|
||||
full_random_latents = False
|
||||
full_random_latents = True
|
||||
if full_random_latents:
|
||||
style = self.noise(b*2, self.gen.latent_dim, x.device)
|
||||
w = self.vectorizer(style)
|
||||
|
|
9
codes/models/eval/__init__.py
Normal file
9
codes/models/eval/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from models.eval.style import StyleTransferEvaluator
|
||||
|
||||
|
||||
def create_evaluator(model, opt_eval, env):
|
||||
type = opt_eval['type']
|
||||
if type == 'style_transfer':
|
||||
return StyleTransferEvaluator(model, opt_eval, env)
|
||||
else:
|
||||
raise NotImplementedError()
|
9
codes/models/eval/evaluator.py
Normal file
9
codes/models/eval/evaluator.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
||||
class Evaluator:
|
||||
def __init__(self, model, opt_eval, env):
|
||||
self.model = model
|
||||
self.opt = opt_eval
|
||||
self.env = env
|
||||
|
||||
def perform_eval(self):
|
||||
return {}
|
35
codes/models/eval/style.py
Normal file
35
codes/models/eval/style.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
import models.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
|
||||
# Evaluate that generates uniform noise to feed into a generator, then calculates a FID score on the results.
|
||||
class StyleTransferEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.im_sz = opt_eval['image_size']
|
||||
self.fid_real_samples = opt_eval['real_fid_path']
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
|
||||
def perform_eval(self):
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
counter = 0
|
||||
for i in range(self.batches_per_eval):
|
||||
batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
|
||||
gen = self.model(batch)
|
||||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||
gen = [gen]
|
||||
gen = gen[self.gen_output_index]
|
||||
for b in range(self.batch_sz):
|
||||
torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
|
||||
counter += 1
|
||||
|
||||
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
||||
2048)}
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
from models.eval import create_evaluator
|
||||
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
|
@ -129,6 +130,13 @@ class Trainer:
|
|||
#### create model
|
||||
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
|
||||
|
||||
### Evaluators
|
||||
self.evaluators = []
|
||||
if 'evaluators' in opt['eval'].keys():
|
||||
for ev_key, ev_opt in opt['eval']['evaluators'].items():
|
||||
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
|
||||
ev_opt, self.model.env))
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
|
@ -241,11 +249,20 @@ class Trainer:
|
|||
|
||||
# log
|
||||
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
|
||||
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
||||
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
||||
|
||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
eval_dict.update(eval.perform_eval())
|
||||
print("Evaluator results: ", eval_dict)
|
||||
for ek, ev in eval_dict.items():
|
||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
|
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
from models.eval import create_evaluator
|
||||
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
|
@ -129,6 +130,13 @@ class Trainer:
|
|||
#### create model
|
||||
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
|
||||
|
||||
### Evaluators
|
||||
self.evaluators = []
|
||||
if 'evaluators' in opt['eval'].keys():
|
||||
for ev_key, ev_opt in opt['eval']['evaluators'].items():
|
||||
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
|
||||
ev_opt, self.model.env))
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
|
@ -241,11 +249,20 @@ class Trainer:
|
|||
|
||||
# log
|
||||
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
|
||||
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
||||
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
||||
|
||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
eval_dict.update(eval.perform_eval())
|
||||
print("Evaluator results: ", eval_dict)
|
||||
for ek, ev in eval_dict.items():
|
||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
|
|
Loading…
Reference in New Issue
Block a user