From e992e18767beb9deda4d49f36d08de64c5f54644 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 1 Jan 2021 11:59:36 -0700
Subject: [PATCH] Add initial_stride term to style_sr

Also fix fid and a networks.py issue.
---
Review notes (free text placed here, after the "---" separator, is not part of
the commit message):
- This patch was restored from a whitespace-collapsed copy. Line breaks were
  rebuilt from the hunk headers; leading indentation and inline-comment
  spacing are reconstructed. If context lines fail to match, apply with
  `git apply --ignore-whitespace`.
- FidForStructuralNetsEvaluator.perform_eval iterates over len(gen) instead of
  self.batch_sz: the DataLoader is created without drop_last, so the final
  batch may hold fewer than batch_sz images and gen[b] would raise IndexError.

 codes/models/improve_rrdb/styled_sr.py | 21 +++++++--------
 codes/trainer/eval/sr_fid.py           | 37 ++++++++++++++++++++++++--
 codes/trainer/networks.py              |  8 +++---
 3 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/codes/models/improve_rrdb/styled_sr.py b/codes/models/improve_rrdb/styled_sr.py
index e970dec3..dea13bc2 100644
--- a/codes/models/improve_rrdb/styled_sr.py
+++ b/codes/models/improve_rrdb/styled_sr.py
@@ -9,7 +9,7 @@ from models.RRDBNet_arch import RRDB
 from models.arch_util import ConvGnLelu, default_init_weights
 from models.stylegan.stylegan2_lucidrains import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, opt_get
 
 
 class EncoderRRDB(nn.Module):
@@ -35,10 +35,10 @@ class EncoderRRDB(nn.Module):
 
 
 class StyledSrEncoder(nn.Module):
-    def __init__(self, fea_out=256):
+    def __init__(self, fea_out=256, initial_stride=1):
         super().__init__()
         # Current assumes fea_out=256.
-        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, norm=False, activation=False, bias=True)
+        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
         self.rrdbs = nn.ModuleList([
             EncoderRRDB(32),
             EncoderRRDB(64),
@@ -56,7 +56,7 @@ class StyledSrEncoder(nn.Module):
 
 
 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, transparent=False, start_level=3, upsample_levels=2):
+    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
         super().__init__()
         total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
         self.image_size = image_size
@@ -75,7 +75,7 @@ class Generator(nn.Module):
             8,    # 1024x1024
         ]
 
-        self.encoder = StyledSrEncoder(filters[start_level])
+        self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
 
         in_out_pairs = list(zip(filters[:-1], filters[1:]))
         self.blocks = nn.ModuleList([])
@@ -88,8 +88,7 @@ class Generator(nn.Module):
                 in_chan,
                 out_chan,
                 upsample=not_first,
-                upsample_rgb=not_last,
-                rgba=transparent
+                upsample_rgb=not_last
             )
             self.blocks.append(block)
 
@@ -108,10 +107,10 @@ class Generator(nn.Module):
 
 
 class StyledSrGenerator(nn.Module):
-    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1):
+    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
         super().__init__()
         self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
-        self.gen = Generator(image_size=image_size, latent_dim=latent_dim)
+        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
         self.mixed_prob = .9
         self._init_weights()
 
@@ -160,5 +159,5 @@ if __name__ == '__main__':
 
 
 @register_model
-def register_opt_styled_sr(opt_net, opt):
-    return StyledSrGenerator(128)
+def register_styled_sr(opt_net, opt):
+    return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
diff --git a/codes/trainer/eval/sr_fid.py b/codes/trainer/eval/sr_fid.py
index d20393ad..b5e6e097 100644
--- a/codes/trainer/eval/sr_fid.py
+++ b/codes/trainer/eval/sr_fid.py
@@ -12,7 +12,8 @@ from data import create_dataset
 from torch.utils.data import DataLoader
 
 
-# Computes the SR FID score for a network.
+# Computes the SR FID score for a network, which is a FID score that attempts to account for structural changes the
+# generator might make from the source image.
 class SrFidEvaluator(evaluator.Evaluator):
     def __init__(self, model, opt_eval, env):
         super().__init__(model, opt_eval, env)
@@ -26,7 +27,7 @@ class SrFidEvaluator(evaluator.Evaluator):
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
 
     def perform_eval(self):
-        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
+        fid_fake_path = osp.join(self.env['base_path'], "..", "sr_fid", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         counter = 0
         for batch in tqdm(self.dataloader):
@@ -49,3 +50,35 @@ class SrFidEvaluator(evaluator.Evaluator):
 
         return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
                                                            2048)}
+
+
+# A "normal" FID computation from a generator that takes LR inputs. Does not account for structural differences at all.
+class FidForStructuralNetsEvaluator(evaluator.Evaluator):
+    def __init__(self, model, opt_eval, env):
+        super().__init__(model, opt_eval, env)
+        self.batch_sz = opt_eval['batch_size']
+        assert self.batch_sz is not None
+        self.dataset = create_dataset(opt_eval['dataset'])
+        self.scale = opt_eval['scale']
+        self.fid_real_samples = opt_eval['dataset']['paths']  # This is assumed to exist for the given dataset.
+        assert isinstance(self.fid_real_samples, str)
+        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
+        self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+
+    def perform_eval(self):
+        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
+        os.makedirs(fid_fake_path, exist_ok=True)
+        counter = 0
+        for batch in tqdm(self.dataloader):
+            lq = batch['lq'].to(self.env['device'])
+            gen = self.model(lq)
+            if not isinstance(gen, list) and not isinstance(gen, tuple):
+                gen = [gen]
+            gen = gen[self.gen_output_index]
+
+            for b in range(len(gen)):
+                torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
+                counter += 1
+
+        return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
+                                                           2048)}
\ No newline at end of file
diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py
index 422d6f91..3a192f6f 100644
--- a/codes/trainer/networks.py
+++ b/codes/trainer/networks.py
@@ -129,10 +129,10 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == "stylegan2_discriminator":
         attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
-        disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
-        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
-    elif which_model == "rrdb_disc":
-        netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
+        from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
+        disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
+        from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
+        netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD