Rework stylegan_for_sr to incorporate structure as an adain block
This commit is contained in:
parent
519ba6f10c
commit
b10bcf6436
|
@ -222,4 +222,3 @@ class RRDBNet(nn.Module):
|
|||
for i, bm in enumerate(self.body):
|
||||
if hasattr(bm, 'bypass_map'):
|
||||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.archs.srflow_orig.module_util as mutil
|
||||
from models.archs.arch_util import default_init_weights, ConvGnSilu
|
||||
from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -231,3 +231,27 @@ class RRDBNet(nn.Module):
|
|||
return results
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
class RRDBLatentWrapper(nn.Module):
    """Wrap an RRDBNet and expose its intermediate features as one latent map.

    The outputs of the RRDB blocks listed in `blocks`, plus the network's
    'last_lr_fea' output, are concatenated along dim 1 and passed through a
    stack of three 1x1 ConvGnLelu layers.

    Args:
        in_nc, out_nc, nf, nb, gc, scale: forwarded unchanged to RRDBNet.
        with_bypass: stored on the wrapper and forwarded to RRDBNet as the
            'rrdb_bypass' option.
        blocks: indices of the RRDB blocks whose outputs are collected.
        pretrain_rrdb_path: optional path to a state dict for the wrapped
            RRDBNet; it is loaded with strict=True.
    """

    def __init__(self, in_nc, out_nc, nf, nb, with_bypass, blocks, pretrain_rrdb_path=None, gc=32, scale=4):
        super().__init__()
        self.with_bypass = with_bypass
        self.blocks = blocks
        # RRDBNet takes an options dict; build the minimal nested structure
        # carrying the two settings this wrapper controls.
        fake_opt = {'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}}
        self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt)
        if pretrain_rrdb_path is not None:
            # Load onto the CPU first: a checkpoint saved from a GPU would
            # otherwise fail on a machine without that device, or allocate
            # throwaway GPU memory. load_state_dict copies the values into the
            # module's own parameters, so the final placement is unaffected.
            rrdb_state_dict = torch.load(pretrain_rrdb_path, map_location='cpu')
            self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True)
        # One group of nf channels per collected block plus one for
        # 'last_lr_fea' (assumes every collected feature map has nf channels).
        out_dim = nf * (len(blocks) + 1)
        self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
                                         ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
                                         ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False))

    def forward(self, lr):
        """Return the post-processed concatenation of the selected RRDB features for `lr`."""
        # With get_steps=True the wrapped network returns a mapping that is
        # indexed below by "block_<idx>" and 'last_lr_fea'.
        rrdbResults = self.wrappedRRDB(lr, get_steps=True)
        blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks]
        blocklist.append(rrdbResults['last_lr_fea'])
        fea = torch.cat(blocklist, dim=1)
        fea = self.postprocess(fea)
        return fea
|
|
@ -1,3 +1,4 @@
|
|||
import functools
|
||||
import math
|
||||
import multiprocessing
|
||||
from contextlib import contextmanager, ExitStack
|
||||
|
@ -371,6 +372,76 @@ class RGBBlock(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise a feature map, then scale and shift it using a style map.

    Scale and shift are each produced from `style` by a kernel_size=1
    ConvGnLelu with norm and activation disabled. The shift projection is
    created with weight_init_factor=0 (presumably zero-initialised weights;
    confirm against ConvGnLelu).
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()
        # Imported here rather than at module level, as in the original.
        from models.archs.arch_util import ConvGnLelu

        shared = dict(kernel_size=1, norm=False, activation=False, bias=True)
        self.style2scale = ConvGnLelu(style_dim, in_channel, **shared)
        self.style2bias = ConvGnLelu(style_dim, in_channel, weight_init_factor=0, **shared)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return `scale(style) * instance_norm(input) + bias(style)`."""
        scale = self.style2scale(style)
        shift = self.style2bias(style)
        normalized = self.norm(input)
        return scale * normalized + shift
|
||||
|
||||
|
||||
class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a learned weight of shape (1, channel, 1, 1).

    The weight starts at zero, so a freshly constructed module returns the
    image unchanged.
    """

    def __init__(self, channel):
        super().__init__()
        initial_scale = torch.zeros(1, channel, 1, 1)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, image, noise):
        """Return `image + weight * noise`."""
        scaled_noise = self.weight * noise
        return image + scaled_noise
|
||||
|
||||
|
||||
class EqualLR:
    """Forward pre-hook that recomputes a module's weight from a stored raw copy.

    `apply` moves the parameter `<name>` to `<name>_orig`; before every
    forward pass the hook sets `<name>` to `<name>_orig * sqrt(2 / fan_in)`.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the raw `<name>_orig` parameter scaled by sqrt(2 / fan_in)."""
        raw = getattr(module, self.name + '_orig')
        # fan_in = size of dim 1 times the element count of one [0][0] slice
        # (input channels times kernel area for a conv weight).
        dim1_size = raw.data.size(1)
        slice_elems = raw.data[0][0].numel()
        fan_in = dim1_size * slice_elems

        return raw * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Install the hook on `module` for parameter `name` and return the hook."""
        hook = EqualLR(name)

        original = getattr(module, name)
        # Drop the parameter from the registry so that the plain tensor set by
        # __call__ can later be stored under the same attribute name.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        # Runs before each forward pass: refresh the scaled weight.
        setattr(module, self.name, self.compute_weight(module))
|
||||
|
||||
|
||||
def equal_lr(module, name='weight'):
    """Install the EqualLR hook for `name` on `module` and return the same module.

    The hook object returned by EqualLR.apply is discarded, so calls can be
    wrapped directly around a module constructor.
    """
    EqualLR.apply(module, name=name)
    return module
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
    """nn.Conv2d whose weight is normal-initialised and rescaled via `equal_lr`.

    All positional and keyword arguments are forwarded to nn.Conv2d. The
    weight is drawn from a standard normal, the bias (when present) is zeroed,
    and the conv is wrapped with `equal_lr` so its effective weight is
    recomputed before each forward pass.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # nn.Conv2d sets bias to None when built with bias=False; zeroing it
        # unconditionally would raise AttributeError in that case.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        """Apply the wrapped convolution to `input`."""
        return self.conv(input)
|
||||
|
||||
|
||||
class Conv2DMod(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
|
||||
super().__init__()
|
||||
|
@ -408,6 +479,54 @@ class Conv2DMod(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class GeneratorBlockWithStructure(nn.Module):
    """Generator block that injects a structural latent through an AdaIN stage.

    Stage 0 (conv -> noise injection -> adaptive instance norm) is conditioned
    on `structure_input`; stages 1 and 2 are modulated convolutions
    conditioned on `istyle`. The block also emits an RGB output via RGBBlock.
    """

    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
        super().__init__()
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            self.upsample = None

        # Uses stylegan1 style blocks for injecting structural latent.
        # NOTE: submodules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
        self.to_noise0 = nn.Linear(1, filters)
        self.noise0 = equal_lr(NoiseInjection(filters))
        self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)

        # First style-modulated convolution stage.
        self.to_style1 = nn.Linear(latent_dim, filters)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(filters, filters, 3)

        # Second style-modulated convolution stage.
        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise, structure_input):
        """Run the block and return `(features, rgb)`.

        `inoise` is cropped on dims 1 and 2 to x's spatial size and projected
        per stage; `structure_input` is resized (nearest) to x's spatial size
        before being used as the AdaIN conditioning.
        """
        if exists(self.upsample):
            x = self.upsample(x)

        height, width = x.shape[2], x.shape[3]
        inoise = inoise[:, :height, :width, :]
        # Project the noise once per stage and move the feature axis to dim 1.
        noise0, noise1, noise2 = (
            projection(inoise).permute((0, 3, 1, 2))
            for projection in (self.to_noise0, self.to_noise1, self.to_noise2)
        )

        structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
        x = self.adain0(self.noise0(self.conv0(x), noise0), structure)

        x = self.activation(self.conv1(x, self.to_style1(istyle)) + noise1)
        x = self.activation(self.conv2(x, self.to_style2(istyle)) + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb
|
||||
|
||||
|
||||
class GeneratorBlock(nn.Module):
|
||||
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
|
||||
super().__init__()
|
||||
|
@ -453,32 +572,6 @@ class GeneratorBlock(nn.Module):
|
|||
return x, rgb
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
    """Residual discriminator block with optional 2x downsampling.

    The main path is two 3x3 convolutions with leaky-ReLU activations,
    optionally followed by Blur and a stride-2 convolution. The skip path is a
    1x1 convolution, strided when downsampling. The two paths are summed and
    scaled by 1/sqrt(2).
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        skip_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=skip_stride)

        main_layers = [
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        ]
        self.net = nn.Sequential(*main_layers)

        if downsample:
            self.downsample = nn.Sequential(
                Blur(),
                nn.Conv2d(filters, filters, 3, padding=1, stride=2),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return `(main_path(x) + skip_path(x)) / sqrt(2)`."""
        shortcut = self.conv_res(x)
        features = self.net(x)
        if exists(self.downsample):
            features = self.downsample(features)
        return (features + shortcut) * (1 / math.sqrt(2))
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
|
||||
fmap_max=512, structure_input=False):
|
||||
|
@ -515,18 +608,22 @@ class Generator(nn.Module):
|
|||
|
||||
self.attns.append(attn_fn)
|
||||
|
||||
block = GeneratorBlock(
|
||||
if structure_input:
|
||||
block_fn = GeneratorBlockWithStructure
|
||||
else:
|
||||
block_fn = GeneratorBlock
|
||||
|
||||
block = block_fn(
|
||||
latent_dim,
|
||||
in_chan,
|
||||
out_chan,
|
||||
upsample=not_first,
|
||||
upsample_rgb=not_last,
|
||||
rgba=transparent,
|
||||
structure_input=structure_input
|
||||
rgba=transparent
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
def forward(self, styles, input_noise, structure_input=None):
|
||||
def forward(self, styles, input_noise, structure_input=None, starting_shape=None):
|
||||
batch_size = styles.shape[0]
|
||||
image_size = self.image_size
|
||||
|
||||
|
@ -535,6 +632,8 @@ class Generator(nn.Module):
|
|||
x = self.to_initial_block(avg_style)
|
||||
else:
|
||||
x = self.initial_block.expand(batch_size, -1, -1, -1)
|
||||
if starting_shape is not None:
|
||||
x = F.interpolate(x, size=starting_shape, mode="bilinear")
|
||||
|
||||
rgb = None
|
||||
styles = styles.transpose(0, 1)
|
||||
|
@ -591,7 +690,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
|||
|
||||
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
|
||||
# b,f,h,w.
|
||||
def forward(self, x, structure_input=None):
|
||||
def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False):
|
||||
b, f, h, w = x.shape
|
||||
|
||||
full_random_latents = True
|
||||
|
@ -614,12 +713,15 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
|||
w_space = self.latent_to_w(self.vectorizer, style)
|
||||
w_styles = self.styles_def_to_tensor(w_space)
|
||||
|
||||
starting_shape = None
|
||||
if fit_starting_shape_to_structure:
|
||||
starting_shape = (x.shape[2] // 32, x.shape[3] // 32)
|
||||
# The underlying model expects the noise as b,h,w,1. Make it so.
|
||||
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
|
||||
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
for block in self.gen.blocks:
|
||||
|
@ -629,6 +731,32 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
|||
nn.init.zeros_(block.to_noise2.bias)
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
    """Residual discriminator block with optional 2x downsampling.

    The main path is two 3x3 convolutions with leaky-ReLU activations,
    optionally followed by Blur and a stride-2 convolution. The skip path is a
    1x1 convolution, strided when downsampling. The two paths are summed and
    scaled by 1/sqrt(2).
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        skip_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=skip_stride)

        main_layers = [
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        ]
        self.net = nn.Sequential(*main_layers)

        if downsample:
            self.downsample = nn.Sequential(
                Blur(),
                nn.Conv2d(filters, filters, 3, padding=1, stride=2),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return `(main_path(x) + skip_path(x)) / sqrt(2)`."""
        shortcut = self.conv_res(x)
        features = self.net(x)
        if exists(self.downsample):
            features = self.downsample(features)
        return (features + shortcut) * (1 / math.sqrt(2))
|
||||
|
||||
|
||||
class StyleGan2Discriminator(nn.Module):
|
||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||
transparent=False, fmap_max=512, input_filters=3):
|
||||
|
|
|
@ -22,6 +22,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
|||
self.im_sz = opt_eval['image_size']
|
||||
self.scale = opt_eval['scale']
|
||||
self.fid_real_samples = opt_eval['real_fid_path']
|
||||
self.embedding_generator = opt_eval['embedding_generator']
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
|
||||
'target_size': self.im_sz,
|
||||
|
@ -30,6 +31,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
|||
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
|
||||
|
||||
def perform_eval(self):
|
||||
embedding_generator = self.env['generators'][self.embedding_generator]
|
||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
||||
|
@ -40,7 +42,8 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
|||
batch_hq = [e['GT'] for e in batch]
|
||||
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
|
||||
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
|
||||
gen = self.model(noise, resized_batch)
|
||||
embedding = embedding_generator(resized_batch)
|
||||
gen = self.model(noise, embedding)
|
||||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||
gen = [gen]
|
||||
gen = gen[self.gen_output_index]
|
||||
|
|
|
@ -148,6 +148,11 @@ def define_G(opt, opt_net, scale=None):
|
|||
from models.archs.srflow_orig import SRFlowNet_arch
|
||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
||||
K=opt_net['K'], opt=opt)
|
||||
elif which_model == 'rrdb_latent_wrapper':
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
||||
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -78,15 +78,20 @@ class Injector(torch.nn.Module):
|
|||
class ImageGeneratorInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImageGeneratorInjector, self).__init__(opt, env)
|
||||
self.grad = opt['grad'] if 'grad' in opt.keys() else True
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
if isinstance(self.input, list):
|
||||
params = extract_params_from_state(self.input, state)
|
||||
else:
|
||||
params = [state[self.input]]
|
||||
if self.grad:
|
||||
results = gen(*params)
|
||||
else:
|
||||
results = gen(state[self.input])
|
||||
with torch.no_grad():
|
||||
results = gen(*params)
|
||||
new_state = {}
|
||||
if isinstance(self.output, list):
|
||||
# Only dereference tuples or lists, not tensors.
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch
|
|||
def main():
|
||||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 5
|
||||
opt['n_thread'] = 20
|
||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
|
@ -46,6 +46,9 @@ class TiledDataset(data.Dataset):
|
|||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# Greyscale not supported.
|
||||
if img is None:
|
||||
print("Error with ", path)
|
||||
return None
|
||||
if len(img.shape) == 2:
|
||||
return None
|
||||
h, w, c = img.shape
|
||||
|
|
|
@ -31,8 +31,8 @@ class Trainer:
|
|||
|
||||
def init(self, opt, launcher, all_networks={}):
|
||||
self._profile = False
|
||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
|
||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
|
||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
|
||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user