From a07e1a7292afb65a95525453f3e16e34ed89c7c1 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 13 Nov 2020 11:03:54 -0700
Subject: [PATCH] Add separate Evaluator module and FID evaluator

---
 codes/models/archs/stylegan2.py |  2 +-
 codes/models/eval/__init__.py   |  9 +++++++++
 codes/models/eval/evaluator.py  |  9 +++++++++
 codes/models/eval/style.py      | 35 +++++++++++++++++++++++++++++++++
 codes/train.py                  | 17 ++++++++++++++++
 codes/train2.py                 | 17 ++++++++++++++++
 6 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 codes/models/eval/__init__.py
 create mode 100644 codes/models/eval/evaluator.py
 create mode 100644 codes/models/eval/style.py

diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py
index ed5bd38b..30d36bb3 100644
--- a/codes/models/archs/stylegan2.py
+++ b/codes/models/archs/stylegan2.py
@@ -562,7 +562,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
 
     def forward(self, x):
         b, f, h, w = x.shape
-        full_random_latents = False
+        full_random_latents = True
         if full_random_latents:
             style = self.noise(b*2, self.gen.latent_dim, x.device)
             w = self.vectorizer(style)

diff --git a/codes/models/eval/__init__.py b/codes/models/eval/__init__.py
new file mode 100644
index 00000000..77daf558
--- /dev/null
+++ b/codes/models/eval/__init__.py
@@ -0,0 +1,9 @@
+from models.eval.style import StyleTransferEvaluator
+
+
+def create_evaluator(model, opt_eval, env):
+    type = opt_eval['type']
+    if type == 'style_transfer':
+        return StyleTransferEvaluator(model, opt_eval, env)
+    else:
+        raise NotImplementedError()
\ No newline at end of file

diff --git a/codes/models/eval/evaluator.py b/codes/models/eval/evaluator.py
new file mode 100644
index 00000000..6e8c665b
--- /dev/null
+++ b/codes/models/eval/evaluator.py
@@ -0,0 +1,9 @@
+# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
+class Evaluator:
+    def __init__(self, model, opt_eval, env):
+        self.model = model
+        self.opt = opt_eval
+        self.env = env
+
+    def perform_eval(self):
+        return {}
\ No newline at end of file
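The two files above define the whole plug-in surface: create_evaluator() dispatches on opt_eval['type'], and Evaluator fixes the constructor signature plus the perform_eval() -> dict contract the training loop consumes. A minimal sketch of a custom evaluator written against that contract; the class name, metric key, and batch shape are hypothetical and not part of this patch:

import torch

import models.eval.evaluator as evaluator


# Hypothetical sketch: the smallest evaluator satisfying the contract used by
# create_evaluator() and consumed by the training loop further down.
class MeanOutputEvaluator(evaluator.Evaluator):
    def perform_eval(self):
        # Push one random batch through the wrapped network.
        batch = torch.rand(4, 3, 64, 64, device=self.env['device'])
        with torch.no_grad():
            out = self.model(batch)
        if isinstance(out, (list, tuple)):
            out = out[0]
        # Return a flat {name: scalar} dict; the trainer forwards each entry
        # to tensorboard via add_scalar() unchanged.
        return {'mean_output': out.mean().item()}

Registering it would take one more branch in create_evaluator(), e.g. elif type == 'mean_output': return MeanOutputEvaluator(model, opt_eval, env).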
diff --git a/codes/models/eval/style.py b/codes/models/eval/style.py
new file mode 100644
index 00000000..b0c9773d
--- /dev/null
+++ b/codes/models/eval/style.py
@@ -0,0 +1,35 @@
+import os
+
+import torch
+import os.path as osp
+import torchvision
+import models.eval.evaluator as evaluator
+from pytorch_fid import fid_score
+
+
+# Evaluator that generates uniform noise to feed into a generator, then calculates an FID score on the results.
+class StyleTransferEvaluator(evaluator.Evaluator):
+    def __init__(self, model, opt_eval, env):
+        super().__init__(model, opt_eval, env)
+        self.batches_per_eval = opt_eval['batches_per_eval']
+        self.batch_sz = opt_eval['batch_size']
+        self.im_sz = opt_eval['image_size']
+        self.fid_real_samples = opt_eval['real_fid_path']
+        self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+
+    def perform_eval(self):
+        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
+        os.makedirs(fid_fake_path, exist_ok=True)
+        counter = 0
+        for i in range(self.batches_per_eval):
+            batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            gen = self.model(batch)
+            if not isinstance(gen, list) and not isinstance(gen, tuple):
+                gen = [gen]
+            gen = gen[self.gen_output_index]
+            for b in range(self.batch_sz):
+                torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
+                counter += 1
+
+        return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
+                                                           2048)}

diff --git a/codes/train.py b/codes/train.py
index ec5211f6..27d34b47 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
 
 import torch
 from data.data_sampler import DistIterSampler
+from models.eval import create_evaluator
 from utils import util, options as option
 from data import create_dataloader, create_dataset
 
@@ -129,6 +130,13 @@ class Trainer:
         #### create model
         self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
 
+        ### Evaluators
+        self.evaluators = []
+        if 'evaluators' in opt['eval'].keys():
+            for ev_key, ev_opt in opt['eval']['evaluators'].items():
+                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
+                                                        ev_opt, self.model.env))
+
         #### resume training
         if resume_state:
             self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
@@ -241,11 +249,20 @@ class Trainer:
                 # log
                 self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
+
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
                     self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                     self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
 
+        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+            eval_dict = {}
+            for eval in self.evaluators:
+                eval_dict.update(eval.perform_eval())
+            print("Evaluator results: ", eval_dict)
+            for ek, ev in eval_dict.items():
+                self.tb_logger.add_scalar(ek, ev, self.current_step)
+
 
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
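The trainer hookup above expects an 'evaluators' block under opt['eval'], where each entry's 'for' key names the network in self.model.networks to evaluate. A hedged sketch of that block as the Python dict a YAML options file would parse into; the label 'fid', the network name 'generator', the path, and the numeric values are illustrative, while the keys mirror what StyleTransferEvaluator actually reads:

# Illustrative only: the shape of opt['eval'] that the __init__ hookup above
# iterates over. Values are placeholders; keys come from the code.
eval_section = {
    'evaluators': {
        'fid': {                       # ev_key; any label works
            'type': 'style_transfer',  # dispatched by create_evaluator()
            'for': 'generator',        # key into self.model.networks
            'batches_per_eval': 10,
            'batch_size': 8,
            'image_size': 128,
            'real_fid_path': '/data/real_images',
            'gen_index': 0,            # optional; defaults to 0
        },
    },
}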
diff --git a/codes/train2.py b/codes/train2.py
index ec5211f6..27d34b47 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
 
 import torch
 from data.data_sampler import DistIterSampler
+from models.eval import create_evaluator
 from utils import util, options as option
 from data import create_dataloader, create_dataset
 
@@ -129,6 +130,13 @@ class Trainer:
         #### create model
         self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
 
+        ### Evaluators
+        self.evaluators = []
+        if 'evaluators' in opt['eval'].keys():
+            for ev_key, ev_opt in opt['eval']['evaluators'].items():
+                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
+                                                        ev_opt, self.model.env))
+
         #### resume training
         if resume_state:
             self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
@@ -241,11 +249,20 @@ class Trainer:
                 # log
                 self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
+
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
                     self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                     self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
 
+        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+            eval_dict = {}
+            for eval in self.evaluators:
+                eval_dict.update(eval.perform_eval())
+            print("Evaluator results: ", eval_dict)
+            for ek, ev in eval_dict.items():
+                self.tb_logger.add_scalar(ek, ev, self.current_step)
+
 
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
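train2.py receives the identical hookup, so either entry point drives the evaluators the same way: every val_freq steps, each evaluator's result dict is merged and pushed to tensorboard. For reference, a hedged standalone sketch of the same flow outside the Trainer; netG, the paths, and the step value are assumed bindings, not taken from this patch:

from models.eval import create_evaluator

# Assumed: netG is a trained generator module already on the 'cuda' device.
env = {'base_path': 'experiments/my_exp', 'step': 50000, 'device': 'cuda'}
opt_eval = {'type': 'style_transfer', 'batches_per_eval': 10, 'batch_size': 8,
            'image_size': 128, 'real_fid_path': '/data/real_images'}

ev = create_evaluator(netG, opt_eval, env)
# Writes generated samples to experiments/my_exp/../fid/50000/, then scores
# them against real_fid_path with pytorch_fid.
print(ev.perform_eval())  # -> {'fid': <float>}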