From 62d3b6496b42a7320e501047b6343dd47ac047f6 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 5 Nov 2020 13:31:34 -0700
Subject: [PATCH] Latent work checkpoint

---
 codes/models/archs/rrdb_with_latent.py | 77 ++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 4 deletions(-)

diff --git a/codes/models/archs/rrdb_with_latent.py b/codes/models/archs/rrdb_with_latent.py
index 7e988129..952d4271 100644
--- a/codes/models/archs/rrdb_with_latent.py
+++ b/codes/models/archs/rrdb_with_latent.py
@@ -7,6 +7,7 @@ import torchvision
 from torch.utils.checkpoint import checkpoint_sequential
 
 from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from models.archs.srg2_classic import Interpolate
 from utils.util import checkpoint
 
 
@@ -74,10 +75,48 @@ class RRDBWithBypassAndLatent(nn.Module):
         out = self.rdb3(out)
         bypass = self.bypass(torch.cat([x, out], dim=1))
         self.bypass_map = bypass.detach().clone()
-        return out * 0.2 * bypass + x
+        residual = out * 0.2 * bypass
+        return residual + x, residual
 
 
 class RRDBNetWithLatent(nn.Module):
+    # 8-layer MLP latent encoder in the vein of StyleGAN's mapping network.
+    def create_linear_latent_encoder(self, latent_size):
+        return nn.Sequential(nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True))
+
+    # Creates a 2D latent by iterating through the provided latent_filters and
+    # doubling the image size at each step.
+    def create_conv_latent_encoder(self, latent_filters):
+        layers = []
+        for i in range(len(latent_filters)-1):
+            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i]))
+            layers.append(Interpolate(2))
+            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
+        return nn.Sequential(*layers)
+
     def __init__(self, in_channels,
                  out_channels,
@@ -144,9 +183,14 @@ class RRDBNetWithLatent(nn.Module):
                 default_init_weights(m, 0.1)
 
     def forward(self, x, latent=None, ref=None):
+        latent_was_none = latent is None
         if latent is None:
             latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
         latent = checkpoint(self.latent_encoder, latent)
+        if latent_was_none:
+            self.latent_mean = torch.mean(latent).detach().cpu()
+            self.latent_std = torch.std(latent).detach().cpu()
+            self.latent_var = torch.var(latent).detach().cpu()
         if self.in_channels > 4:
             x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
             if ref is None:
@@ -156,8 +200,12 @@ class RRDBNetWithLatent(nn.Module):
             x_lg = x
         feat = self.conv_first(x_lg)
         body_feat = feat
+        self.block_residual_means = []
+        self.block_residual_stds = []
         for bl in self.body:
-            body_feat = checkpoint(bl, body_feat, latent)
+            body_feat, residual = checkpoint(bl, body_feat, latent)
+            self.block_residual_means.append(torch.mean(residual).cpu())
+            self.block_residual_stds.append(torch.std(residual).cpu())
         body_feat = self.conv_body(body_feat)
         feat = feat + body_feat
         # upsample
@@ -175,6 +223,17 @@ class RRDBNetWithLatent(nn.Module):
         for i, bm in enumerate(self.body):
             torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
+    def get_debug_values(self, s, n):
+        blk_stds, blk_means = {}, {}
+        for i, (std, mean) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
+            blk_stds['block_%i' % (i+1,)] = std
+            blk_means['block_%i' % (i+1,)] = mean
+        return {'encoded_latent_mean': self.latent_mean,
+                'encoded_latent_std': self.latent_std,
+                'encoded_latent_var': self.latent_var,
+                'blocks_mean': blk_means,
+                'blocks_std': blk_stds}
+
 
 # Based heavily on the same VGG arch used for the discriminator.
 class LatentEstimator(nn.Module):
@@ -212,6 +271,7 @@ class LatentEstimator(nn.Module):
 
         self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
         self.linear2 = nn.Linear(latent_size*2, latent_size)
+        self.tanh = nn.Tanh()
 
     def compute_body(self, x):
         fea = self.lrelu(self.conv0_0(x))
@@ -236,5 +296,14 @@ class LatentEstimator(nn.Module):
         fea = checkpoint(self.compute_body, x)
         fea = fea.contiguous().view(fea.size(0), -1)
         fea = self.linear1(fea)
-        out = self.linear2(fea)
-        return out
\ No newline at end of file
+        out = self.tanh(self.linear2(fea))
+        self.latent_mean = torch.mean(out)
+        self.latent_std = torch.std(out)
+        self.latent_var = torch.var(out)
+        return out
+
+    def get_debug_values(self, s, n):
+        return {'latent_estimator_mean': self.latent_mean,
+                'latent_estimator_std': self.latent_std,
+                'latent_estimator_var': self.latent_var}
+
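
Reviewer note: with this patch, RRDBWithBypassAndLatent.forward() returns an
(output, residual) pair, and RRDBNetWithLatent caches per-block residual
statistics plus encoded-latent statistics that get_debug_values() reports.
A minimal sketch of how a caller might exercise the new path; the constructor
arguments below are hypothetical, since __init__'s full signature is truncated
in the hunks above:

    import torch
    from models.archs.rrdb_with_latent import RRDBNetWithLatent

    # Hypothetical constructor arguments -- the real signature is not
    # visible in this patch.
    net = RRDBNetWithLatent(in_channels=3, out_channels=3)
    net.eval()

    lr = torch.randn(4, 3, 32, 32)
    with torch.no_grad():
        # latent=None: forward() draws a random latent internally and
        # caches its mean/std/var for logging.
        sr = net(lr)

    # Scalar statistics wired up by this patch, keyed per RRDB block.
    dbg = net.get_debug_values(0, 'generator')
    print(dbg['encoded_latent_mean'], dbg['blocks_mean'])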
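
The new create_conv_latent_encoder() depends on Interpolate imported from
models.archs.srg2_classic, which is not shown in this patch. Assuming it is a
thin nn.Module wrapper around F.interpolate (so that a scaling step can sit
inside an nn.Sequential), it would look roughly like this:

    import torch.nn as nn
    import torch.nn.functional as F

    # Assumed shape of models.archs.srg2_classic.Interpolate: Interpolate(2)
    # doubles the spatial size of the 2D latent at each encoder step.
    class Interpolate(nn.Module):
        def __init__(self, factor):
            super().__init__()
            self.factor = factor

        def forward(self, x):
            return F.interpolate(x, scale_factor=self.factor, mode='nearest')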