More latent work
This commit is contained in:
parent
6be6c92e5d
commit
9e2c96ad5d
|
@ -134,13 +134,15 @@ class RRDBNetWithLatent(nn.Module):
|
|||
num_blocks=23,
|
||||
growth_channels=32,
|
||||
blocks_per_checkpoint=4,
|
||||
scale=4):
|
||||
scale=4,
|
||||
bottom_latent_only=False):
|
||||
super(RRDBNetWithLatent, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.scale = scale
|
||||
self.in_channels = in_channels
|
||||
self.nf = mid_channels
|
||||
self.bottom_latent_only = bottom_latent_only
|
||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||
|
@ -172,6 +174,9 @@ class RRDBNetWithLatent(nn.Module):
|
|||
mults = [4, 2, 1]
|
||||
b, f, h, w = x.shape
|
||||
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
|
||||
if self.bottom_latent_only:
|
||||
latent[1] = torch.zeros_like(latent[1])
|
||||
latent[2] = torch.zeros_like(latent[2])
|
||||
latent = self.latent_encoder(latent)
|
||||
if latent_was_none is None:
|
||||
self.latent_mean = torch.mean(latent).detach().cpu()
|
||||
|
@ -289,3 +294,80 @@ class LatentEstimator(nn.Module):
|
|||
'latent_estimator_std': self.latent_std,
|
||||
'latent_estimator_var': self.latent_var}
|
||||
|
||||
|
||||
class LatentEstimator2(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
super(LatentEstimator2, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
|
||||
# [256, 8, 8]
|
||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
|
||||
# [256, 4, 4]
|
||||
self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.l = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
|
||||
self.l2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(.2, inplace=True)
|
||||
self.norm = nn.InstanceNorm2d(nf*4)
|
||||
|
||||
def compute_body(self, x):
|
||||
fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
|
||||
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||
|
||||
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||
|
||||
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
|
||||
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
|
||||
o3 = self.norm(self.l2(self.l(fea)))
|
||||
|
||||
return F.interpolate(o3, scale_factor=4, mode="nearest")
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.lrelu(self.conv0_0(x))
|
||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||
o = checkpoint(self.compute_body, fea)
|
||||
out = [o,\
|
||||
torch.zeros((o.shape[0],128,16,16), device=o.device),\
|
||||
torch.zeros((o.shape[0],64,32,32), device=o.device)]
|
||||
self.latent_mean = torch.mean(out[-1])
|
||||
self.latent_std = torch.std(out[-1])
|
||||
self.latent_var = torch.var(out[-1])
|
||||
return out
|
||||
|
||||
def get_debug_values(self, s, n):
|
||||
return {'latent_estimator_mean': self.latent_mean,
|
||||
'latent_estimator_std': self.latent_std,
|
||||
'latent_estimator_var': self.latent_var}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import models.archs.rcan as rcan
|
|||
import models.archs.ChainedEmbeddingGen as chained
|
||||
from models.archs import srg2_classic
|
||||
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
||||
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent
|
||||
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
||||
from models.archs.teco_resgen import TecoGen
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
@ -122,8 +122,13 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
elif which_model == "rrdb_with_latent":
|
||||
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
|
||||
scale=opt_net['scale'],
|
||||
bottom_latent_only=opt_net['bottom_latent_only'])
|
||||
elif which_model == "latent_estimator":
|
||||
if opt_net['version'] == 2:
|
||||
netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
|
||||
else:
|
||||
overwrite = [1,2] if opt_net['only_base_level'] else []
|
||||
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
|
||||
else:
|
||||
|
|
|
@ -280,7 +280,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user