forked from mrq/DL-Art-School
Rework stylegan_for_sr to incorporate structure as an adain block
This commit is contained in:
parent
519ba6f10c
commit
b10bcf6436
|
@ -222,4 +222,3 @@ class RRDBNet(nn.Module):
|
||||||
for i, bm in enumerate(self.body):
|
for i, bm in enumerate(self.body):
|
||||||
if hasattr(bm, 'bypass_map'):
|
if hasattr(bm, 'bypass_map'):
|
||||||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import models.archs.srflow_orig.module_util as mutil
|
import models.archs.srflow_orig.module_util as mutil
|
||||||
from models.archs.arch_util import default_init_weights, ConvGnSilu
|
from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,3 +231,27 @@ class RRDBNet(nn.Module):
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBLatentWrapper(nn.Module):
|
||||||
|
def __init__(self, in_nc, out_nc, nf, nb, with_bypass, blocks, pretrain_rrdb_path=None, gc=32, scale=4):
|
||||||
|
super().__init__()
|
||||||
|
self.with_bypass = with_bypass
|
||||||
|
self.blocks = blocks
|
||||||
|
fake_opt = { 'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}}
|
||||||
|
self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt)
|
||||||
|
if pretrain_rrdb_path is not None:
|
||||||
|
rrdb_state_dict = torch.load(pretrain_rrdb_path)
|
||||||
|
self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||||
|
out_dim = nf * (len(blocks) + 1)
|
||||||
|
self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
|
||||||
|
ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
|
||||||
|
ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False))
|
||||||
|
|
||||||
|
def forward(self, lr):
|
||||||
|
rrdbResults = self.wrappedRRDB(lr, get_steps=True)
|
||||||
|
blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks]
|
||||||
|
blocklist.append(rrdbResults['last_lr_fea'])
|
||||||
|
fea = torch.cat(blocklist, dim=1)
|
||||||
|
fea = self.postprocess(fea)
|
||||||
|
return fea
|
|
@ -1,3 +1,4 @@
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from contextlib import contextmanager, ExitStack
|
from contextlib import contextmanager, ExitStack
|
||||||
|
@ -371,6 +372,76 @@ class RGBBlock(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveInstanceNorm(nn.Module):
|
||||||
|
def __init__(self, in_channel, style_dim):
|
||||||
|
super().__init__()
|
||||||
|
from models.archs.arch_util import ConvGnLelu
|
||||||
|
self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
|
||||||
|
self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
|
||||||
|
self.norm = nn.InstanceNorm2d(in_channel)
|
||||||
|
|
||||||
|
def forward(self, input, style):
|
||||||
|
gamma = self.style2scale(style)
|
||||||
|
beta = self.style2bias(style)
|
||||||
|
out = self.norm(input)
|
||||||
|
out = gamma * out + beta
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseInjection(nn.Module):
|
||||||
|
def __init__(self, channel):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, image, noise):
|
||||||
|
return image + self.weight * noise
|
||||||
|
|
||||||
|
|
||||||
|
class EqualLR:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def compute_weight(self, module):
|
||||||
|
weight = getattr(module, self.name + '_orig')
|
||||||
|
fan_in = weight.data.size(1) * weight.data[0][0].numel()
|
||||||
|
|
||||||
|
return weight * math.sqrt(2 / fan_in)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(module, name):
|
||||||
|
fn = EqualLR(name)
|
||||||
|
|
||||||
|
weight = getattr(module, name)
|
||||||
|
del module._parameters[name]
|
||||||
|
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
|
||||||
|
module.register_forward_pre_hook(fn)
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def __call__(self, module, input):
|
||||||
|
weight = self.compute_weight(module)
|
||||||
|
setattr(module, self.name, weight)
|
||||||
|
|
||||||
|
|
||||||
|
def equal_lr(module, name='weight'):
|
||||||
|
EqualLR.apply(module, name)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class EqualConv2d(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
conv = nn.Conv2d(*args, **kwargs)
|
||||||
|
conv.weight.data.normal_()
|
||||||
|
conv.bias.data.zero_()
|
||||||
|
self.conv = equal_lr(conv)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.conv(input)
|
||||||
|
|
||||||
|
|
||||||
class Conv2DMod(nn.Module):
|
class Conv2DMod(nn.Module):
|
||||||
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
|
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -408,6 +479,54 @@ class Conv2DMod(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorBlockWithStructure(nn.Module):
|
||||||
|
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
|
||||||
|
super().__init__()
|
||||||
|
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
|
||||||
|
|
||||||
|
# Uses stylegan1 style blocks for injecting structural latent.
|
||||||
|
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
|
||||||
|
self.to_noise0 = nn.Linear(1, filters)
|
||||||
|
self.noise0 = equal_lr(NoiseInjection(filters))
|
||||||
|
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
|
||||||
|
|
||||||
|
self.to_style1 = nn.Linear(latent_dim, filters)
|
||||||
|
self.to_noise1 = nn.Linear(1, filters)
|
||||||
|
self.conv1 = Conv2DMod(filters, filters, 3)
|
||||||
|
|
||||||
|
self.to_style2 = nn.Linear(latent_dim, filters)
|
||||||
|
self.to_noise2 = nn.Linear(1, filters)
|
||||||
|
self.conv2 = Conv2DMod(filters, filters, 3)
|
||||||
|
|
||||||
|
self.activation = leaky_relu()
|
||||||
|
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
|
||||||
|
|
||||||
|
def forward(self, x, prev_rgb, istyle, inoise, structure_input):
|
||||||
|
if exists(self.upsample):
|
||||||
|
x = self.upsample(x)
|
||||||
|
|
||||||
|
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
||||||
|
noise0 = self.to_noise0(inoise).permute((0, 3, 1, 2))
|
||||||
|
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
|
||||||
|
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
|
||||||
|
|
||||||
|
structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
|
||||||
|
x = self.conv0(x)
|
||||||
|
x = self.noise0(x, noise0)
|
||||||
|
x = self.adain0(x, structure)
|
||||||
|
|
||||||
|
style1 = self.to_style1(istyle)
|
||||||
|
x = self.conv1(x, style1)
|
||||||
|
x = self.activation(x + noise1)
|
||||||
|
|
||||||
|
style2 = self.to_style2(istyle)
|
||||||
|
x = self.conv2(x, style2)
|
||||||
|
x = self.activation(x + noise2)
|
||||||
|
|
||||||
|
rgb = self.to_rgb(x, prev_rgb, istyle)
|
||||||
|
return x, rgb
|
||||||
|
|
||||||
|
|
||||||
class GeneratorBlock(nn.Module):
|
class GeneratorBlock(nn.Module):
|
||||||
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
|
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -453,32 +572,6 @@ class GeneratorBlock(nn.Module):
|
||||||
return x, rgb
|
return x, rgb
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorBlock(nn.Module):
|
|
||||||
def __init__(self, input_channels, filters, downsample=True):
|
|
||||||
super().__init__()
|
|
||||||
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
nn.Conv2d(input_channels, filters, 3, padding=1),
|
|
||||||
leaky_relu(),
|
|
||||||
nn.Conv2d(filters, filters, 3, padding=1),
|
|
||||||
leaky_relu()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.downsample = nn.Sequential(
|
|
||||||
Blur(),
|
|
||||||
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
|
|
||||||
) if downsample else None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
res = self.conv_res(x)
|
|
||||||
x = self.net(x)
|
|
||||||
if exists(self.downsample):
|
|
||||||
x = self.downsample(x)
|
|
||||||
x = (x + res) * (1 / math.sqrt(2))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
|
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
|
||||||
fmap_max=512, structure_input=False):
|
fmap_max=512, structure_input=False):
|
||||||
|
@ -515,18 +608,22 @@ class Generator(nn.Module):
|
||||||
|
|
||||||
self.attns.append(attn_fn)
|
self.attns.append(attn_fn)
|
||||||
|
|
||||||
block = GeneratorBlock(
|
if structure_input:
|
||||||
|
block_fn = GeneratorBlockWithStructure
|
||||||
|
else:
|
||||||
|
block_fn = GeneratorBlock
|
||||||
|
|
||||||
|
block = block_fn(
|
||||||
latent_dim,
|
latent_dim,
|
||||||
in_chan,
|
in_chan,
|
||||||
out_chan,
|
out_chan,
|
||||||
upsample=not_first,
|
upsample=not_first,
|
||||||
upsample_rgb=not_last,
|
upsample_rgb=not_last,
|
||||||
rgba=transparent,
|
rgba=transparent
|
||||||
structure_input=structure_input
|
|
||||||
)
|
)
|
||||||
self.blocks.append(block)
|
self.blocks.append(block)
|
||||||
|
|
||||||
def forward(self, styles, input_noise, structure_input=None):
|
def forward(self, styles, input_noise, structure_input=None, starting_shape=None):
|
||||||
batch_size = styles.shape[0]
|
batch_size = styles.shape[0]
|
||||||
image_size = self.image_size
|
image_size = self.image_size
|
||||||
|
|
||||||
|
@ -535,6 +632,8 @@ class Generator(nn.Module):
|
||||||
x = self.to_initial_block(avg_style)
|
x = self.to_initial_block(avg_style)
|
||||||
else:
|
else:
|
||||||
x = self.initial_block.expand(batch_size, -1, -1, -1)
|
x = self.initial_block.expand(batch_size, -1, -1, -1)
|
||||||
|
if starting_shape is not None:
|
||||||
|
x = F.interpolate(x, size=starting_shape, mode="bilinear")
|
||||||
|
|
||||||
rgb = None
|
rgb = None
|
||||||
styles = styles.transpose(0, 1)
|
styles = styles.transpose(0, 1)
|
||||||
|
@ -591,7 +690,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
|
|
||||||
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
|
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
|
||||||
# b,f,h,w.
|
# b,f,h,w.
|
||||||
def forward(self, x, structure_input=None):
|
def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False):
|
||||||
b, f, h, w = x.shape
|
b, f, h, w = x.shape
|
||||||
|
|
||||||
full_random_latents = True
|
full_random_latents = True
|
||||||
|
@ -614,12 +713,15 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
w_space = self.latent_to_w(self.vectorizer, style)
|
w_space = self.latent_to_w(self.vectorizer, style)
|
||||||
w_styles = self.styles_def_to_tensor(w_space)
|
w_styles = self.styles_def_to_tensor(w_space)
|
||||||
|
|
||||||
|
starting_shape = None
|
||||||
|
if fit_starting_shape_to_structure:
|
||||||
|
starting_shape = (x.shape[2] // 32, x.shape[3] // 32)
|
||||||
# The underlying model expects the noise as b,h,w,1. Make it so.
|
# The underlying model expects the noise as b,h,w,1. Make it so.
|
||||||
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
|
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
|
||||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||||
|
|
||||||
for block in self.gen.blocks:
|
for block in self.gen.blocks:
|
||||||
|
@ -629,6 +731,32 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
nn.init.zeros_(block.to_noise2.bias)
|
nn.init.zeros_(block.to_noise2.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorBlock(nn.Module):
|
||||||
|
def __init__(self, input_channels, filters, downsample=True):
|
||||||
|
super().__init__()
|
||||||
|
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(input_channels, filters, 3, padding=1),
|
||||||
|
leaky_relu(),
|
||||||
|
nn.Conv2d(filters, filters, 3, padding=1),
|
||||||
|
leaky_relu()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
Blur(),
|
||||||
|
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
|
||||||
|
) if downsample else None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = self.conv_res(x)
|
||||||
|
x = self.net(x)
|
||||||
|
if exists(self.downsample):
|
||||||
|
x = self.downsample(x)
|
||||||
|
x = (x + res) * (1 / math.sqrt(2))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class StyleGan2Discriminator(nn.Module):
|
class StyleGan2Discriminator(nn.Module):
|
||||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||||
transparent=False, fmap_max=512, input_filters=3):
|
transparent=False, fmap_max=512, input_filters=3):
|
||||||
|
|
|
@ -22,6 +22,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||||
self.im_sz = opt_eval['image_size']
|
self.im_sz = opt_eval['image_size']
|
||||||
self.scale = opt_eval['scale']
|
self.scale = opt_eval['scale']
|
||||||
self.fid_real_samples = opt_eval['real_fid_path']
|
self.fid_real_samples = opt_eval['real_fid_path']
|
||||||
|
self.embedding_generator = opt_eval['embedding_generator']
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
|
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
|
||||||
'target_size': self.im_sz,
|
'target_size': self.im_sz,
|
||||||
|
@ -30,6 +31,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||||
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
|
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
|
embedding_generator = self.env['generators'][self.embedding_generator]
|
||||||
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
|
||||||
os.makedirs(fid_fake_path, exist_ok=True)
|
os.makedirs(fid_fake_path, exist_ok=True)
|
||||||
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
|
||||||
|
@ -40,7 +42,8 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
|
||||||
batch_hq = [e['GT'] for e in batch]
|
batch_hq = [e['GT'] for e in batch]
|
||||||
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
|
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
|
||||||
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
|
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
|
||||||
gen = self.model(noise, resized_batch)
|
embedding = embedding_generator(resized_batch)
|
||||||
|
gen = self.model(noise, embedding)
|
||||||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||||
gen = [gen]
|
gen = [gen]
|
||||||
gen = gen[self.gen_output_index]
|
gen = gen[self.gen_output_index]
|
||||||
|
|
|
@ -148,6 +148,11 @@ def define_G(opt, opt_net, scale=None):
|
||||||
from models.archs.srflow_orig import SRFlowNet_arch
|
from models.archs.srflow_orig import SRFlowNet_arch
|
||||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
||||||
K=opt_net['K'], opt=opt)
|
K=opt_net['K'], opt=opt)
|
||||||
|
elif which_model == 'rrdb_latent_wrapper':
|
||||||
|
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
||||||
|
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
|
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||||
|
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
|
@ -78,15 +78,20 @@ class Injector(torch.nn.Module):
|
||||||
class ImageGeneratorInjector(Injector):
|
class ImageGeneratorInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(ImageGeneratorInjector, self).__init__(opt, env)
|
super(ImageGeneratorInjector, self).__init__(opt, env)
|
||||||
|
self.grad = opt['grad'] if 'grad' in opt.keys() else True
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
with autocast(enabled=self.env['opt']['fp16']):
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
if isinstance(self.input, list):
|
if isinstance(self.input, list):
|
||||||
params = extract_params_from_state(self.input, state)
|
params = extract_params_from_state(self.input, state)
|
||||||
|
else:
|
||||||
|
params = [state[self.input]]
|
||||||
|
if self.grad:
|
||||||
results = gen(*params)
|
results = gen(*params)
|
||||||
else:
|
else:
|
||||||
results = gen(state[self.input])
|
with torch.no_grad():
|
||||||
|
results = gen(*params)
|
||||||
new_state = {}
|
new_state = {}
|
||||||
if isinstance(self.output, list):
|
if isinstance(self.output, list):
|
||||||
# Only dereference tuples or lists, not tensors.
|
# Only dereference tuples or lists, not tensors.
|
||||||
|
|
|
@ -13,7 +13,7 @@ import torch
|
||||||
def main():
|
def main():
|
||||||
split_img = False
|
split_img = False
|
||||||
opt = {}
|
opt = {}
|
||||||
opt['n_thread'] = 5
|
opt['n_thread'] = 20
|
||||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
|
@ -46,6 +46,9 @@ class TiledDataset(data.Dataset):
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
# Greyscale not supported.
|
# Greyscale not supported.
|
||||||
|
if img is None:
|
||||||
|
print("Error with ", path)
|
||||||
|
return None
|
||||||
if len(img.shape) == 2:
|
if len(img.shape) == 2:
|
||||||
return None
|
return None
|
||||||
h, w, c = img.shape
|
h, w, c = img.shape
|
||||||
|
|
|
@ -31,8 +31,8 @@ class Trainer:
|
||||||
|
|
||||||
def init(self, opt, launcher, all_networks={}):
|
def init(self, opt, launcher, all_networks={}):
|
||||||
self._profile = False
|
self._profile = False
|
||||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
|
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
|
||||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
|
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
|
||||||
|
|
||||||
#### loading resume state if exists
|
#### loading resume state if exists
|
||||||
if opt['path'].get('resume_state', None):
|
if opt['path'].get('resume_state', None):
|
||||||
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user