forked from mrq/DL-Art-School
Add a FID evaluator for stylegan with structural guidance
This commit is contained in:
parent
c9258e2da3
commit
125cb16dce
|
@ -1,3 +1,4 @@
|
|||
from models.eval.sr_style import SrStyleTransferEvaluator
|
||||
from models.eval.style import StyleTransferEvaluator
|
||||
|
||||
|
||||
|
@ -5,5 +6,7 @@ def create_evaluator(model, opt_eval, env):
|
|||
type = opt_eval['type']
|
||||
if type == 'style_transfer':
|
||||
return StyleTransferEvaluator(model, opt_eval, env)
|
||||
elif type == 'sr_stylegan':
|
||||
return SrStyleTransferEvaluator(model, opt_eval, env)
|
||||
else:
|
||||
raise NotImplementedError()
|
54
codes/models/eval/sr_style.py
Normal file
54
codes/models/eval/sr_style.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
from torch.utils.data import BatchSampler
|
||||
|
||||
import models.eval.evaluator as evaluator
|
||||
from pytorch_fid import fid_score
|
||||
|
||||
|
||||
# Evaluate that feeds a LR structure into the input, then calculates a FID score on the results added to
|
||||
# the interpolated LR structure.
|
||||
from data.stylegan2_dataset import Stylegan2Dataset
|
||||
|
||||
|
||||
class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env)
|
||||
self.batches_per_eval = opt_eval['batches_per_eval']
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.im_sz = opt_eval['image_size']
|
||||
self.scale = opt_eval['scale']
|
||||
self.fid_real_samples = opt_eval['real_fid_path']
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
|
||||
'target_size': self.im_sz,
|
||||
'aug_prob': 0,
|
||||
'transparent': False})
|
||||
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
|
||||
|
||||
def perform_eval(self):
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
||||
os.makedirs(fid_real_path, exist_ok=True)
|
||||
counter = 0
|
||||
for batch in self.sampler:
|
||||
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
|
||||
batch_hq = [e['GT'] for e in batch]
|
||||
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
|
||||
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
|
||||
gen = self.model(noise, resized_batch)
|
||||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||
gen = [gen]
|
||||
gen = gen[self.gen_output_index]
|
||||
out = gen + torch.nn.functional.interpolate(resized_batch, scale_factor=self.scale, mode='bilinear')
|
||||
for b in range(self.batch_sz):
|
||||
torchvision.utils.save_image(out[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
|
||||
torchvision.utils.save_image(batch_hq[b], osp.join(fid_real_path, "%i_.png" % (counter)))
|
||||
counter += 1
|
||||
|
||||
return {"fid": fid_score.calculate_fid_given_paths([fid_real_path, fid_fake_path], self.batch_sz, True,
|
||||
2048)}
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_celebA_separated_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user