diff --git a/codes/models/archs/rrdb_with_latent.py b/codes/models/archs/rrdb_with_latent.py index 5b0239a9..2be4ec18 100644 --- a/codes/models/archs/rrdb_with_latent.py +++ b/codes/models/archs/rrdb_with_latent.py @@ -134,13 +134,15 @@ class RRDBNetWithLatent(nn.Module): num_blocks=23, growth_channels=32, blocks_per_checkpoint=4, - scale=4): + scale=4, + bottom_latent_only=False): super(RRDBNetWithLatent, self).__init__() self.num_blocks = num_blocks self.blocks_per_checkpoint = blocks_per_checkpoint self.scale = scale self.in_channels = in_channels self.nf = mid_channels + self.bottom_latent_only = bottom_latent_only first_conv_stride = 1 if in_channels <= 4 else scale first_conv_ksize = 3 if first_conv_stride == 1 else 7 first_conv_padding = 1 if first_conv_stride == 1 else 3 @@ -172,6 +174,9 @@ class RRDBNetWithLatent(nn.Module): mults = [4, 2, 1] b, f, h, w = x.shape latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults] + if self.bottom_latent_only: + latent[1] = torch.zeros_like(latent[1]) + latent[2] = torch.zeros_like(latent[2]) latent = self.latent_encoder(latent) if latent_was_none is None: self.latent_mean = torch.mean(latent).detach().cpu() @@ -289,3 +294,80 @@ class LatentEstimator(nn.Module): 'latent_estimator_std': self.latent_std, 'latent_estimator_var': self.latent_var} + +class LatentEstimator2(nn.Module): + def __init__(self, in_nc, nf): + super(LatentEstimator2, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + + # [256, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + + # [256, 4, 4] + self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True) + self.l = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True) + self.l2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True) + + self.lrelu = nn.LeakyReLU(.2, inplace=True) + self.norm = nn.InstanceNorm2d(nf*4) + + def compute_body(self, x): + fea = self.lrelu(self.bn1_0(self.conv1_0(x))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + fea = self.lrelu(self.bn5_0(self.conv5_0(fea))) + fea = self.lrelu(self.bn5_1(self.conv5_1(fea))) + o3 = self.norm(self.l2(self.l(fea))) + + return F.interpolate(o3, scale_factor=4, mode="nearest") + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + o = checkpoint(self.compute_body, fea) + out = [o,\ + torch.zeros((o.shape[0],128,16,16), device=o.device),\ + torch.zeros((o.shape[0],64,32,32), device=o.device)] + self.latent_mean = torch.mean(out[-1]) + self.latent_std = torch.std(out[-1]) + self.latent_var = torch.var(out[-1]) + return out + + def get_debug_values(self, s, n): + return {'latent_estimator_mean': self.latent_mean, + 'latent_estimator_std': self.latent_std, + 'latent_estimator_var': self.latent_var} + diff --git a/codes/models/networks.py b/codes/models/networks.py index 12bbf919..b168bf96 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -20,7 +20,7 @@ import models.archs.rcan as rcan import models.archs.ChainedEmbeddingGen as chained from models.archs import srg2_classic from models.archs.pyramid_arch import BasicResamplingFlowNet -from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent +from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2 from models.archs.teco_resgen import TecoGen logger = logging.getLogger('base') @@ -122,10 +122,15 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == "rrdb_with_latent": netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], - blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale']) + blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], + scale=opt_net['scale'], + bottom_latent_only=opt_net['bottom_latent_only']) elif which_model == "latent_estimator": - overwrite = [1,2] if opt_net['only_base_level'] else [] - netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite) + if opt_net['version'] == 2: + netG = LatentEstimator2(in_nc=3, nf=opt_net['nf']) + else: + overwrite = [1,2] if opt_net['only_base_level'] else [] + netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/train2.py b/codes/train2.py index cd608f97..13151cf7 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -280,7 +280,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)