From 125cb16dce67159f8bdafe5dc0567804b8e75d01 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 14 Nov 2020 20:16:07 -0700 Subject: [PATCH] Add a FID evaluator for stylegan with structural guidance --- codes/models/eval/__init__.py | 3 ++ codes/models/eval/sr_style.py | 54 +++++++++++++++++++++++++++++++++++ codes/train.py | 2 +- 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 codes/models/eval/sr_style.py diff --git a/codes/models/eval/__init__.py b/codes/models/eval/__init__.py index 77daf558..111ca71f 100644 --- a/codes/models/eval/__init__.py +++ b/codes/models/eval/__init__.py @@ -1,3 +1,4 @@ +from models.eval.sr_style import SrStyleTransferEvaluator from models.eval.style import StyleTransferEvaluator @@ -5,5 +6,7 @@ def create_evaluator(model, opt_eval, env): type = opt_eval['type'] if type == 'style_transfer': return StyleTransferEvaluator(model, opt_eval, env) + elif type == 'sr_stylegan': + return SrStyleTransferEvaluator(model, opt_eval, env) else: raise NotImplementedError() \ No newline at end of file diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py new file mode 100644 index 00000000..ca46fae8 --- /dev/null +++ b/codes/models/eval/sr_style.py @@ -0,0 +1,54 @@ +import os + +import torch +import os.path as osp +import torchvision +from torch.utils.data import BatchSampler + +import models.eval.evaluator as evaluator +from pytorch_fid import fid_score + + +# Evaluate that feeds a LR structure into the input, then calculates a FID score on the results added to +# the interpolated LR structure. +from data.stylegan2_dataset import Stylegan2Dataset + + +class SrStyleTransferEvaluator(evaluator.Evaluator): + def __init__(self, model, opt_eval, env): + super().__init__(model, opt_eval, env) + self.batches_per_eval = opt_eval['batches_per_eval'] + self.batch_sz = opt_eval['batch_size'] + self.im_sz = opt_eval['image_size'] + self.scale = opt_eval['scale'] + self.fid_real_samples = opt_eval['real_fid_path'] + self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 + self.dataset = Stylegan2Dataset({'path': self.fid_real_samples, + 'target_size': self.im_sz, + 'aug_prob': 0, + 'transparent': False}) + self.sampler = BatchSampler(self.dataset, self.batch_sz, False) + + def perform_eval(self): + fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"])) + os.makedirs(fid_fake_path, exist_ok=True) + fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"])) + os.makedirs(fid_real_path, exist_ok=True) + counter = 0 + for batch in self.sampler: + noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device']) + batch_hq = [e['GT'] for e in batch] + batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device']) + resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area") + gen = self.model(noise, resized_batch) + if not isinstance(gen, list) and not isinstance(gen, tuple): + gen = [gen] + gen = gen[self.gen_output_index] + out = gen + torch.nn.functional.interpolate(resized_batch, scale_factor=self.scale, mode='bilinear') + for b in range(self.batch_sz): + torchvision.utils.save_image(out[b], osp.join(fid_fake_path, "%i_.png" % (counter))) + torchvision.utils.save_image(batch_hq[b], osp.join(fid_real_path, "%i_.png" % (counter))) + counter += 1 + + return {"fid": fid_score.calculate_fid_given_paths([fid_real_path, fid_fake_path], self.batch_sz, True, + 2048)} diff --git a/codes/train.py b/codes/train.py index 3ecc1e04..2aff1dc9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_celebA_separated_disc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()