Add a FID evaluator for stylegan with structural guidance

This commit is contained in:
James Betker 2020-11-14 20:16:07 -07:00
parent c9258e2da3
commit 125cb16dce
3 changed files with 58 additions and 1 deletions

View File

@ -1,3 +1,4 @@
from models.eval.sr_style import SrStyleTransferEvaluator
from models.eval.style import StyleTransferEvaluator from models.eval.style import StyleTransferEvaluator
@ -5,5 +6,7 @@ def create_evaluator(model, opt_eval, env):
type = opt_eval['type'] type = opt_eval['type']
if type == 'style_transfer': if type == 'style_transfer':
return StyleTransferEvaluator(model, opt_eval, env) return StyleTransferEvaluator(model, opt_eval, env)
elif type == 'sr_stylegan':
return SrStyleTransferEvaluator(model, opt_eval, env)
else: else:
raise NotImplementedError() raise NotImplementedError()

View File

@ -0,0 +1,54 @@
import os
import torch
import os.path as osp
import torchvision
from torch.utils.data import BatchSampler
import models.eval.evaluator as evaluator
from pytorch_fid import fid_score
# Evaluate that feeds a LR structure into the input, then calculates a FID score on the results added to
# the interpolated LR structure.
from data.stylegan2_dataset import Stylegan2Dataset
class SrStyleTransferEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
self.batches_per_eval = opt_eval['batches_per_eval']
self.batch_sz = opt_eval['batch_size']
self.im_sz = opt_eval['image_size']
self.scale = opt_eval['scale']
self.fid_real_samples = opt_eval['real_fid_path']
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
'target_size': self.im_sz,
'aug_prob': 0,
'transparent': False})
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
os.makedirs(fid_real_path, exist_ok=True)
counter = 0
for batch in self.sampler:
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
batch_hq = [e['GT'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
gen = self.model(noise, resized_batch)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]
gen = gen[self.gen_output_index]
out = gen + torch.nn.functional.interpolate(resized_batch, scale_factor=self.scale, mode='bilinear')
for b in range(self.batch_sz):
torchvision.utils.save_image(out[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
torchvision.utils.save_image(batch_hq[b], osp.join(fid_real_path, "%i_.png" % (counter)))
counter += 1
return {"fid": fid_score.calculate_fid_given_paths([fid_real_path, fid_fake_path], self.batch_sz, True,
2048)}

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_celebA_separated_disc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()