Add initial_stride term to style_sr

Also fix fid and a networks.py issue.
This commit is contained in:
James Betker 2021-01-01 11:59:36 -07:00
parent 9864fe4c04
commit e992e18767
3 changed files with 49 additions and 17 deletions

View File

@ -9,7 +9,7 @@ from models.RRDBNet_arch import RRDB
from models.arch_util import ConvGnLelu, default_init_weights
from models.stylegan.stylegan2_lucidrains import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
from trainer.networks import register_model
from utils.util import checkpoint
from utils.util import checkpoint, opt_get
class EncoderRRDB(nn.Module):
@ -35,10 +35,10 @@ class EncoderRRDB(nn.Module):
class StyledSrEncoder(nn.Module):
def __init__(self, fea_out=256):
def __init__(self, fea_out=256, initial_stride=1):
super().__init__()
# Current assumes fea_out=256.
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, norm=False, activation=False, bias=True)
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
self.rrdbs = nn.ModuleList([
EncoderRRDB(32),
EncoderRRDB(64),
@ -56,7 +56,7 @@ class StyledSrEncoder(nn.Module):
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, transparent=False, start_level=3, upsample_levels=2):
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
super().__init__()
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
self.image_size = image_size
@ -75,7 +75,7 @@ class Generator(nn.Module):
8, # 1024x1024
]
self.encoder = StyledSrEncoder(filters[start_level])
self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
in_out_pairs = list(zip(filters[:-1], filters[1:]))
self.blocks = nn.ModuleList([])
@ -88,8 +88,7 @@ class Generator(nn.Module):
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last,
rgba=transparent
upsample_rgb=not_last
)
self.blocks.append(block)
@ -108,10 +107,10 @@ class Generator(nn.Module):
class StyledSrGenerator(nn.Module):
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1):
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
super().__init__()
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
self.mixed_prob = .9
self._init_weights()
@ -160,5 +159,5 @@ if __name__ == '__main__':
@register_model
def register_opt_styled_sr(opt_net, opt):
return StyledSrGenerator(128)
def register_styled_sr(opt_net, opt):
return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))

View File

@ -12,7 +12,8 @@ from data import create_dataset
from torch.utils.data import DataLoader
# Computes the SR FID score for a network.
# Computes the SR FID score for a network, which is a FID score that attempts to account for structural changes the
# generator might make from the source image.
class SrFidEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
@ -26,7 +27,7 @@ class SrFidEvaluator(evaluator.Evaluator):
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
fid_fake_path = osp.join(self.env['base_path'], "..", "sr_fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
counter = 0
for batch in tqdm(self.dataloader):
@ -49,3 +50,35 @@ class SrFidEvaluator(evaluator.Evaluator):
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
2048)}
# A "normal" FID computation from a generator that takes LR inputs. Does not account for structural differences at all.
class FidForStructuralNetsEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
self.batch_sz = opt_eval['batch_size']
assert self.batch_sz is not None
self.dataset = create_dataset(opt_eval['dataset'])
self.scale = opt_eval['scale']
self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset.
assert isinstance(self.fid_real_samples, str)
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
counter = 0
for batch in tqdm(self.dataloader):
lq = batch['lq'].to(self.env['device'])
gen = self.model(lq)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]
gen = gen[self.gen_output_index]
for b in range(self.batch_sz):
torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
counter += 1
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
2048)}

View File

@ -129,10 +129,10 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "stylegan2_discriminator":
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
elif which_model == "rrdb_disc":
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD