Add initial_stride term to style_sr
Also fix fid and a networks.py issue.
This commit is contained in:
parent
9864fe4c04
commit
e992e18767
|
@ -9,7 +9,7 @@ from models.RRDBNet_arch import RRDB
|
||||||
from models.arch_util import ConvGnLelu, default_init_weights
|
from models.arch_util import ConvGnLelu, default_init_weights
|
||||||
from models.stylegan.stylegan2_lucidrains import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
|
from models.stylegan.stylegan2_lucidrains import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint, opt_get
|
||||||
|
|
||||||
|
|
||||||
class EncoderRRDB(nn.Module):
|
class EncoderRRDB(nn.Module):
|
||||||
|
@ -35,10 +35,10 @@ class EncoderRRDB(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class StyledSrEncoder(nn.Module):
|
class StyledSrEncoder(nn.Module):
|
||||||
def __init__(self, fea_out=256):
|
def __init__(self, fea_out=256, initial_stride=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Current assumes fea_out=256.
|
# Current assumes fea_out=256.
|
||||||
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, norm=False, activation=False, bias=True)
|
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
|
||||||
self.rrdbs = nn.ModuleList([
|
self.rrdbs = nn.ModuleList([
|
||||||
EncoderRRDB(32),
|
EncoderRRDB(32),
|
||||||
EncoderRRDB(64),
|
EncoderRRDB(64),
|
||||||
|
@ -56,7 +56,7 @@ class StyledSrEncoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, image_size, latent_dim, transparent=False, start_level=3, upsample_levels=2):
|
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
|
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
|
@ -75,7 +75,7 @@ class Generator(nn.Module):
|
||||||
8, # 1024x1024
|
8, # 1024x1024
|
||||||
]
|
]
|
||||||
|
|
||||||
self.encoder = StyledSrEncoder(filters[start_level])
|
self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
|
||||||
|
|
||||||
in_out_pairs = list(zip(filters[:-1], filters[1:]))
|
in_out_pairs = list(zip(filters[:-1], filters[1:]))
|
||||||
self.blocks = nn.ModuleList([])
|
self.blocks = nn.ModuleList([])
|
||||||
|
@ -88,8 +88,7 @@ class Generator(nn.Module):
|
||||||
in_chan,
|
in_chan,
|
||||||
out_chan,
|
out_chan,
|
||||||
upsample=not_first,
|
upsample=not_first,
|
||||||
upsample_rgb=not_last,
|
upsample_rgb=not_last
|
||||||
rgba=transparent
|
|
||||||
)
|
)
|
||||||
self.blocks.append(block)
|
self.blocks.append(block)
|
||||||
|
|
||||||
|
@ -108,10 +107,10 @@ class Generator(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class StyledSrGenerator(nn.Module):
|
class StyledSrGenerator(nn.Module):
|
||||||
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1):
|
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
|
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
|
||||||
self.gen = Generator(image_size=image_size, latent_dim=latent_dim)
|
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
|
||||||
self.mixed_prob = .9
|
self.mixed_prob = .9
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
|
@ -160,5 +159,5 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_opt_styled_sr(opt_net, opt):
|
def register_styled_sr(opt_net, opt):
|
||||||
return StyledSrGenerator(128)
|
return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
|
||||||
|
|
|
@ -12,7 +12,8 @@ from data import create_dataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
# Computes the SR FID score for a network.
|
# Computes the SR FID score for a network, which is a FID score that attempts to account for structural changes the
|
||||||
|
# generator might make from the source image.
|
||||||
class SrFidEvaluator(evaluator.Evaluator):
|
class SrFidEvaluator(evaluator.Evaluator):
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
super().__init__(model, opt_eval, env)
|
super().__init__(model, opt_eval, env)
|
||||||
|
@ -26,7 +27,7 @@ class SrFidEvaluator(evaluator.Evaluator):
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
|
fid_fake_path = osp.join(self.env['base_path'], "..", "sr_fid", str(self.env["step"]))
|
||||||
os.makedirs(fid_fake_path, exist_ok=True)
|
os.makedirs(fid_fake_path, exist_ok=True)
|
||||||
counter = 0
|
counter = 0
|
||||||
for batch in tqdm(self.dataloader):
|
for batch in tqdm(self.dataloader):
|
||||||
|
@ -49,3 +50,35 @@ class SrFidEvaluator(evaluator.Evaluator):
|
||||||
|
|
||||||
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
||||||
2048)}
|
2048)}
|
||||||
|
|
||||||
|
|
||||||
|
# A "normal" FID computation from a generator that takes LR inputs. Does not account for structural differences at all.
|
||||||
|
class FidForStructuralNetsEvaluator(evaluator.Evaluator):
    """A "normal" FID evaluator for a generator that takes LR inputs.

    Every batch from the evaluation dataset is pushed through the model, the
    selected generator output is written to disk image-by-image, and the FID
    between those images and the dataset's real samples is returned. No
    attempt is made to account for structural differences between the
    generated image and its source.

    Options read from ``opt_eval``:
        batch_size: required; also forwarded to the FID computation.
        dataset: options handed to ``create_dataset``; its ``paths`` entry
            must be a single string and is used as the real-sample location.
        scale: stored on the instance (not used by this class itself).
        gen_index: optional index into the model output when the model
            returns a list/tuple; defaults to 0.
    """

    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env)
        self.batch_sz = opt_eval['batch_size']
        assert self.batch_sz is not None
        self.dataset = create_dataset(opt_eval['dataset'])
        self.scale = opt_eval['scale']
        self.fid_real_samples = opt_eval['dataset']['paths']  # This is assumed to exist for the given dataset.
        assert isinstance(self.fid_real_samples, str)
        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
        self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0

    def perform_eval(self):
        """Generate images for the whole dataset and return ``{"fid": score}``.

        Generated images are saved as ``<counter>_.png`` under
        ``<base_path>/../fid/<step>``, which is created if missing.
        """
        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
        os.makedirs(fid_fake_path, exist_ok=True)
        counter = 0
        for batch in tqdm(self.dataloader):
            lq = batch['lq'].to(self.env['device'])
            gen = self.model(lq)
            # Normalize single-output models to a sequence so gen_index applies uniformly.
            if not isinstance(gen, (list, tuple)):
                gen = [gen]
            gen = gen[self.gen_output_index]

            # Iterate over the images actually produced rather than range(batch_sz):
            # the dataloader does not drop the last batch, so the final batch may hold
            # fewer than batch_sz images and fixed-size indexing would raise IndexError.
            for image in gen:
                torchvision.utils.save_image(image, osp.join(fid_fake_path, "%i_.png" % (counter)))
                counter += 1

        # NOTE(review): the trailing `True, 2048` arguments are presumably the CUDA flag and
        # the Inception feature dimensionality -- confirm against fid_score's signature.
        return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
                                                           2048)}
|
|
@ -129,10 +129,10 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
elif which_model == "stylegan2_discriminator":
|
elif which_model == "stylegan2_discriminator":
|
||||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
|
||||||
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
||||||
elif which_model == "rrdb_disc":
|
from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
|
||||||
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
|
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user