More latent work
This commit is contained in:
parent
6be6c92e5d
commit
9e2c96ad5d
|
@ -134,13 +134,15 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
num_blocks=23,
|
num_blocks=23,
|
||||||
growth_channels=32,
|
growth_channels=32,
|
||||||
blocks_per_checkpoint=4,
|
blocks_per_checkpoint=4,
|
||||||
scale=4):
|
scale=4,
|
||||||
|
bottom_latent_only=False):
|
||||||
super(RRDBNetWithLatent, self).__init__()
|
super(RRDBNetWithLatent, self).__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.nf = mid_channels
|
self.nf = mid_channels
|
||||||
|
self.bottom_latent_only = bottom_latent_only
|
||||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||||
|
@ -172,6 +174,9 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
mults = [4, 2, 1]
|
mults = [4, 2, 1]
|
||||||
b, f, h, w = x.shape
|
b, f, h, w = x.shape
|
||||||
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
|
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
|
||||||
|
if self.bottom_latent_only:
|
||||||
|
latent[1] = torch.zeros_like(latent[1])
|
||||||
|
latent[2] = torch.zeros_like(latent[2])
|
||||||
latent = self.latent_encoder(latent)
|
latent = self.latent_encoder(latent)
|
||||||
if latent_was_none is None:
|
if latent_was_none is None:
|
||||||
self.latent_mean = torch.mean(latent).detach().cpu()
|
self.latent_mean = torch.mean(latent).detach().cpu()
|
||||||
|
@ -289,3 +294,80 @@ class LatentEstimator(nn.Module):
|
||||||
'latent_estimator_std': self.latent_std,
|
'latent_estimator_std': self.latent_std,
|
||||||
'latent_estimator_var': self.latent_var}
|
'latent_estimator_var': self.latent_var}
|
||||||
|
|
||||||
|
|
||||||
|
class LatentEstimator2(nn.Module):
|
||||||
|
def __init__(self, in_nc, nf):
|
||||||
|
super(LatentEstimator2, self).__init__()
|
||||||
|
# [64, 128, 128]
|
||||||
|
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||||
|
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||||
|
# [64, 64, 64]
|
||||||
|
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||||
|
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
|
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||||
|
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
|
|
||||||
|
# [128, 32, 32]
|
||||||
|
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||||
|
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
|
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||||
|
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
|
|
||||||
|
# [256, 16, 16]
|
||||||
|
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
|
||||||
|
# [256, 8, 8]
|
||||||
|
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
|
||||||
|
# [256, 4, 4]
|
||||||
|
self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.l = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
|
||||||
|
self.l2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(.2, inplace=True)
|
||||||
|
self.norm = nn.InstanceNorm2d(nf*4)
|
||||||
|
|
||||||
|
def compute_body(self, x):
|
||||||
|
fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
|
||||||
|
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||||
|
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||||
|
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
|
||||||
|
o3 = self.norm(self.l2(self.l(fea)))
|
||||||
|
|
||||||
|
return F.interpolate(o3, scale_factor=4, mode="nearest")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.lrelu(self.conv0_0(x))
|
||||||
|
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||||
|
o = checkpoint(self.compute_body, fea)
|
||||||
|
out = [o,\
|
||||||
|
torch.zeros((o.shape[0],128,16,16), device=o.device),\
|
||||||
|
torch.zeros((o.shape[0],64,32,32), device=o.device)]
|
||||||
|
self.latent_mean = torch.mean(out[-1])
|
||||||
|
self.latent_std = torch.std(out[-1])
|
||||||
|
self.latent_var = torch.var(out[-1])
|
||||||
|
return out
|
||||||
|
|
||||||
|
def get_debug_values(self, s, n):
|
||||||
|
return {'latent_estimator_mean': self.latent_mean,
|
||||||
|
'latent_estimator_std': self.latent_std,
|
||||||
|
'latent_estimator_var': self.latent_var}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import models.archs.rcan as rcan
|
||||||
import models.archs.ChainedEmbeddingGen as chained
|
import models.archs.ChainedEmbeddingGen as chained
|
||||||
from models.archs import srg2_classic
|
from models.archs import srg2_classic
|
||||||
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
||||||
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent
|
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
||||||
from models.archs.teco_resgen import TecoGen
|
from models.archs.teco_resgen import TecoGen
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
@ -122,10 +122,15 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == "rrdb_with_latent":
|
elif which_model == "rrdb_with_latent":
|
||||||
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
|
||||||
|
scale=opt_net['scale'],
|
||||||
|
bottom_latent_only=opt_net['bottom_latent_only'])
|
||||||
elif which_model == "latent_estimator":
|
elif which_model == "latent_estimator":
|
||||||
overwrite = [1,2] if opt_net['only_base_level'] else []
|
if opt_net['version'] == 2:
|
||||||
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
|
netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
|
||||||
|
else:
|
||||||
|
overwrite = [1,2] if opt_net['only_base_level'] else []
|
||||||
|
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
|
@ -280,7 +280,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user