More latent work

This commit is contained in:
James Betker 2020-11-07 20:38:56 -07:00
parent 6be6c92e5d
commit 9e2c96ad5d
3 changed files with 93 additions and 6 deletions

View File

@ -134,13 +134,15 @@ class RRDBNetWithLatent(nn.Module):
num_blocks=23,
growth_channels=32,
blocks_per_checkpoint=4,
scale=4):
scale=4,
bottom_latent_only=False):
super(RRDBNetWithLatent, self).__init__()
self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale
self.in_channels = in_channels
self.nf = mid_channels
self.bottom_latent_only = bottom_latent_only
first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
first_conv_padding = 1 if first_conv_stride == 1 else 3
@ -172,6 +174,9 @@ class RRDBNetWithLatent(nn.Module):
mults = [4, 2, 1]
b, f, h, w = x.shape
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
if self.bottom_latent_only:
latent[1] = torch.zeros_like(latent[1])
latent[2] = torch.zeros_like(latent[2])
latent = self.latent_encoder(latent)
if latent_was_none is None:
self.latent_mean = torch.mean(latent).detach().cpu()
@ -289,3 +294,80 @@ class LatentEstimator(nn.Module):
'latent_estimator_std': self.latent_std,
'latent_estimator_var': self.latent_var}
class LatentEstimator2(nn.Module):
def __init__(self, in_nc, nf):
super(LatentEstimator2, self).__init__()
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [256, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [256, 4, 4]
self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True)
self.l = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
self.l2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
self.lrelu = nn.LeakyReLU(.2, inplace=True)
self.norm = nn.InstanceNorm2d(nf*4)
def compute_body(self, x):
fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
o3 = self.norm(self.l2(self.l(fea)))
return F.interpolate(o3, scale_factor=4, mode="nearest")
def forward(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
o = checkpoint(self.compute_body, fea)
out = [o,\
torch.zeros((o.shape[0],128,16,16), device=o.device),\
torch.zeros((o.shape[0],64,32,32), device=o.device)]
self.latent_mean = torch.mean(out[-1])
self.latent_std = torch.std(out[-1])
self.latent_var = torch.var(out[-1])
return out
def get_debug_values(self, s, n):
return {'latent_estimator_mean': self.latent_mean,
'latent_estimator_std': self.latent_std,
'latent_estimator_var': self.latent_var}

View File

@ -20,7 +20,7 @@ import models.archs.rcan as rcan
import models.archs.ChainedEmbeddingGen as chained
from models.archs import srg2_classic
from models.archs.pyramid_arch import BasicResamplingFlowNet
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base')
@ -122,10 +122,15 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == "rrdb_with_latent":
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
scale=opt_net['scale'],
bottom_latent_only=opt_net['bottom_latent_only'])
elif which_model == "latent_estimator":
overwrite = [1,2] if opt_net['only_base_level'] else []
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
if opt_net['version'] == 2:
netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
else:
overwrite = [1,2] if opt_net['only_base_level'] else []
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -280,7 +280,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)