Latent work checkpoint

parent fd6cdba88f
commit 62d3b6496b
@@ -7,6 +7,7 @@ import torchvision
 from torch.utils.checkpoint import checkpoint_sequential
 from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from models.archs.srg2_classic import Interpolate
 from utils.util import checkpoint
 
 
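Note that the forward passes below call checkpoint(...) imported from utils.util rather than torch's function directly. A minimal sketch of what such a wrapper plausibly does, given how it is invoked in this diff (an assumption for illustration, not the repository's actual utils.util implementation):

    import torch.utils.checkpoint

    def checkpoint(fn, *args):
        # Trade compute for memory: drop fn's activations in the forward
        # pass and recompute them during backward.
        return torch.utils.checkpoint.checkpoint(fn, *args)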
@@ -74,10 +75,48 @@ class RRDBWithBypassAndLatent(nn.Module):
         out = self.rdb3(out)
         bypass = self.bypass(torch.cat([x, out], dim=1))
         self.bypass_map = bypass.detach().clone()
-        return out * 0.2 * bypass + x
+        residual = out * .2 * bypass
+        return residual + x, residual
 
 
 class RRDBNetWithLatent(nn.Module):
+    # 8-layer MLP in the vein of StyleGAN.
+    def create_linear_latent_encoder(self, latent_size):
+        return nn.Sequential(nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True))
+
+    # Creates a 2D latent by iterating through the provided latent_filters and doubling the
+    # image size each step.
+    def create_conv_latent_encoder(self, latent_filters):
+        layers = []
+        for i in range(len(latent_filters)-1):
+            # append, not extend: each helper returns a single nn.Module, which
+            # is not iterable, so list.extend() would raise a TypeError here.
+            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i]))
+            layers.append(Interpolate(2))
+            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
+        return nn.Sequential(*layers)
+
     def __init__(self,
                  in_channels,
                  out_channels,

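To sanity-check the conv encoder's geometry: each loop iteration upsamples 2x, so len(latent_filters)-1 iterations scale the latent map by 2^(len(latent_filters)-1). A self-contained sketch with stand-ins for ConvGnLelu and Interpolate (the stand-ins and filter sizes are assumptions for illustration, not the repository's implementations):

    import torch
    import torch.nn as nn

    def conv_gn_lelu(in_f, out_f):
        # Rough stand-in for models.archs.arch_util.ConvGnLelu.
        return nn.Sequential(nn.Conv2d(in_f, out_f, 3, padding=1),
                             nn.GroupNorm(8, out_f),
                             nn.LeakyReLU(0.2, inplace=True))

    def conv_encoder(filters):
        layers = []
        for i in range(len(filters) - 1):
            layers.append(conv_gn_lelu(filters[i], filters[i]))
            layers.append(nn.Upsample(scale_factor=2))  # stand-in for Interpolate(2)
            layers.append(conv_gn_lelu(filters[i], filters[i + 1]))
        return nn.Sequential(*layers)

    enc = conv_encoder([64, 32, 16])
    print(enc(torch.randn(1, 64, 8, 8)).shape)  # -> torch.Size([1, 16, 32, 32])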
@@ -144,9 +183,14 @@ class RRDBNetWithLatent(nn.Module):
         default_init_weights(m, 0.1)
 
     def forward(self, x, latent=None, ref=None):
+        latent_was_none = latent
         if latent is None:
             latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
         latent = checkpoint(self.latent_encoder, latent)
+        if latent_was_none is None:
+            self.latent_mean = torch.mean(latent).detach().cpu()
+            self.latent_std = torch.std(latent).detach().cpu()
+            self.latent_var = torch.var(latent).detach().cpu()
         if self.in_channels > 4:
             x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
             if ref is None:

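The latent statistics are recorded only when the latent was sampled internally, i.e. when the caller passed latent=None. Storing the original argument in latent_was_none works, but the name suggests a boolean; an equivalent, more explicit formulation (a cosmetic sketch, not what the commit does):

    latent_was_none = latent is None   # plain bool flag
    if latent is None:
        latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
    latent = checkpoint(self.latent_encoder, latent)
    if latent_was_none:
        self.latent_mean = torch.mean(latent).detach().cpu()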
@@ -156,8 +200,12 @@ class RRDBNetWithLatent(nn.Module):
             x_lg = x
         feat = self.conv_first(x_lg)
         body_feat = feat
+        self.block_residual_means = []
+        self.block_residual_stds = []
         for bl in self.body:
-            body_feat = checkpoint(bl, body_feat, latent)
+            body_feat, residual = checkpoint(bl, body_feat, latent)
+            self.block_residual_means.append(torch.mean(residual).cpu())
+            self.block_residual_stds.append(torch.std(residual).cpu())
         body_feat = self.conv_body(body_feat)
         feat = feat + body_feat
         # upsample

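This works because torch.utils.checkpoint.checkpoint passes tuple outputs through unchanged, so a checkpointed block can return both its output and the residual used for the per-block statistics. A minimal demonstration (the toy block is an assumption, standing in for RRDBWithBypassAndLatent):

    import torch
    from torch.utils.checkpoint import checkpoint

    def block(x):
        residual = x * 0.2
        return x + residual, residual  # mirrors the new two-value return

    x = torch.randn(2, 4, requires_grad=True)
    out, res = checkpoint(block, x)   # tuple comes back intact
    print(out.shape, res.mean().item())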
@@ -175,6 +223,17 @@ class RRDBNetWithLatent(nn.Module):
         for i, bm in enumerate(self.body):
             torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
+    def get_debug_values(self, s, n):
+        blk_stds, blk_means = {}, {}
+        for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
+            blk_stds['block_%i' % (i+1,)] = s
+            blk_means['block_%i' % (i+1,)] = m
+        return {'encoded_latent_mean': self.latent_mean,
+                'encoded_latent_std': self.latent_std,
+                'encoded_latent_var': self.latent_var,
+                'blocks_mean': blk_means,
+                'blocks_std': blk_stds}
+
 
 # Based heavily on the same VGG arch used for the discriminator.
 class LatentEstimator(nn.Module):

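One readability nit: the loop variable s shadows the method's s parameter (presumably the step). The parameter is unused after the loop, so this is harmless, but a non-shadowing dict-comprehension variant would be (cosmetic sketch only):

    def get_debug_values(self, step, net_name):
        blk_stds = {'block_%i' % (i + 1): v for i, v in enumerate(self.block_residual_stds)}
        blk_means = {'block_%i' % (i + 1): v for i, v in enumerate(self.block_residual_means)}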
@@ -212,6 +271,7 @@ class LatentEstimator(nn.Module):
 
         self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
         self.linear2 = nn.Linear(latent_size*2, latent_size)
+        self.tanh = nn.Tanh()
 
     def compute_body(self, x):
         fea = self.lrelu(self.conv0_0(x))

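The int(final_nf * 4 * 4) input size of linear1 implies compute_body reduces its input to a 4x4 feature map before flattening, and the new Tanh bounds the estimated latent to [-1, 1]. A quick check of the squashing (toy tensor, assumed shapes):

    import torch
    z = torch.randn(4, 128) * 3.0
    t = torch.tanh(z)
    print(t.min().item(), t.max().item())  # everything lands in [-1, 1]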
@@ -236,5 +296,14 @@ class LatentEstimator(nn.Module):
         fea = checkpoint(self.compute_body, x)
         fea = fea.contiguous().view(fea.size(0), -1)
         fea = self.linear1(fea)
-        out = self.linear2(fea)
-        return out
+        out = self.tanh(self.linear2(fea))
+        self.latent_mean = torch.mean(out)
+        self.latent_std = torch.std(out)
+        self.latent_var = torch.var(out)
+        return out
+
+    def get_debug_values(self, s, n):
+        return {'latent_estimator_mean': self.latent_mean,
+                'latent_estimator_std': self.latent_std,
+                'latent_estimator_var': self.latent_var}
+
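Unlike RRDBNetWithLatent above, these statistics are stored without .detach().cpu(), so each one keeps its autograd graph alive until overwritten; plausibly fine for a work checkpoint, but worth noting. A self-contained mock of the modified tail (final_nf and latent_size are hypothetical values):

    import torch
    import torch.nn as nn

    latent_size = 128
    linear1 = nn.Linear(512 * 4 * 4, latent_size * 2)   # assumes final_nf=512 and a 4x4 map
    linear2 = nn.Linear(latent_size * 2, latent_size)
    tanh = nn.Tanh()

    fea = torch.randn(8, 512, 4, 4)                     # pretend compute_body output
    fea = fea.contiguous().view(fea.size(0), -1)
    out = tanh(linear2(linear1(fea)))
    print(out.shape, out.mean().item(), out.std().item())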
|
Loading…
Reference in New Issue
Block a user