Latent work checkpoint

This commit is contained in:
James Betker 2020-11-05 13:31:34 -07:00
parent fd6cdba88f
commit 62d3b6496b

View File

@ -7,6 +7,7 @@ import torchvision
from torch.utils.checkpoint import checkpoint_sequential
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from models.archs.srg2_classic import Interpolate
from utils.util import checkpoint
@ -74,10 +75,48 @@ class RRDBWithBypassAndLatent(nn.Module):
out = self.rdb3(out)
bypass = self.bypass(torch.cat([x, out], dim=1))
self.bypass_map = bypass.detach().clone()
return out * 0.2 * bypass + x
residual = out * .2 * bypass
return residual + x, residual
class RRDBNetWithLatent(nn.Module):
# 8-layer MLP in the vein of StyleGAN.
def create_linear_latent_encoder(self, latent_size):
return nn.Sequential(nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
# Creates a 2D latent by iterating through the provided latent_filters and doubling the
# image size each step.
def create_conv_latent_encoder(self, latent_filters):
layers = []
for i in range(len(latent_filters)-1):
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
layers.extend(Interpolate(2))
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
return nn.Sequential(*layers)
def __init__(self,
in_channels,
out_channels,
@ -144,9 +183,14 @@ class RRDBNetWithLatent(nn.Module):
default_init_weights(m, 0.1)
def forward(self, x, latent=None, ref=None):
latent_was_none = latent
if latent is None:
latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
latent = checkpoint(self.latent_encoder, latent)
if latent_was_none is None:
self.latent_mean = torch.mean(latent).detach().cpu()
self.latent_std = torch.std(latent).detach().cpu()
self.latent_var = torch.var(latent).detach().cpu()
if self.in_channels > 4:
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
@ -156,8 +200,12 @@ class RRDBNetWithLatent(nn.Module):
x_lg = x
feat = self.conv_first(x_lg)
body_feat = feat
self.block_residual_means = []
self.block_residual_stds = []
for bl in self.body:
body_feat = checkpoint(bl, body_feat, latent)
body_feat, residual = checkpoint(bl, body_feat, latent)
self.block_residual_means.append(torch.mean(residual).cpu())
self.block_residual_stds.append(torch.std(residual).cpu())
body_feat = self.conv_body(body_feat)
feat = feat + body_feat
# upsample
@ -175,6 +223,17 @@ class RRDBNetWithLatent(nn.Module):
for i, bm in enumerate(self.body):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, s, n):
blk_stds, blk_means = {}, {}
for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
blk_stds['block_%i' % (i+1,)] = s
blk_means['block_%i' % (i+1,)] = m
return {'encoded_latent_mean': self.latent_mean,
'encoded_latent_std': self.latent_std,
'encoded_latent_var': self.latent_var,
'blocks_mean': blk_means,
'blocks_std': blk_stds}
# Based heavily on the same VGG arch used for the discriminator.
class LatentEstimator(nn.Module):
@ -212,6 +271,7 @@ class LatentEstimator(nn.Module):
self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
self.linear2 = nn.Linear(latent_size*2, latent_size)
self.tanh = nn.Tanh()
def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x))
@ -236,5 +296,14 @@ class LatentEstimator(nn.Module):
fea = checkpoint(self.compute_body, x)
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.linear1(fea)
out = self.linear2(fea)
return out
out = self.tanh(self.linear2(fea))
self.latent_mean = torch.mean(out)
self.latent_std = torch.std(out)
self.latent_var = torch.var(out)
return out
def get_debug_values(self, s, n):
return {'latent_estimator_mean': self.latent_mean,
'latent_estimator_std': self.latent_std,
'latent_estimator_var': self.latent_var}