Latent work checkpoint
parent fd6cdba88f
commit 62d3b6496b
@@ -7,6 +7,7 @@ import torchvision
 from torch.utils.checkpoint import checkpoint_sequential
 
 from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from models.archs.srg2_classic import Interpolate
 from utils.util import checkpoint
 
 
@@ -74,10 +75,48 @@ class RRDBWithBypassAndLatent(nn.Module):
         out = self.rdb3(out)
         bypass = self.bypass(torch.cat([x, out], dim=1))
         self.bypass_map = bypass.detach().clone()
-        return out * 0.2 * bypass + x
+        residual = out * .2 * bypass
+        return residual + x, residual
 
 
 class RRDBNetWithLatent(nn.Module):
+    # 8-layer MLP in the vein of StyleGAN.
+    def create_linear_latent_encoder(self, latent_size):
+        return nn.Sequential(nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                             nn.Linear(latent_size, latent_size),
+                             nn.BatchNorm1d(latent_size),
+                             nn.LeakyReLU(negative_slope=0.2, inplace=True))
+
+    # Creates a 2D latent by iterating through the provided latent_filters and doubling the
+    # image size each step.
+    def create_conv_latent_encoder(self, latent_filters):
+        layers = []
+        for i in range(len(latent_filters)-1):
+            layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
+            layers.extend(Interpolate(2))
+            layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
+        return nn.Sequential(*layers)
+
     def __init__(self,
                  in_channels,
                  out_channels,
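Note on the new create_conv_latent_encoder(): list.extend() expects an iterable, but each ConvGnLelu and Interpolate here is a single module instance. If ConvGnLelu is a plain nn.Module rather than an nn.Sequential, these calls would raise TypeError at construction time, so append() is presumably the intent. A minimal corrected sketch, assuming ConvGnLelu(in_nf, out_nf) and Interpolate(factor) behave as the imports above suggest:

    # Sketch only: same structure as the committed method, with append()
    # so each module lands in the layer list as a single element.
    def create_conv_latent_encoder(self, latent_filters):
        layers = []
        for i in range(len(latent_filters) - 1):
            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i]))
            layers.append(Interpolate(2))  # double the spatial size each step
            layers.append(ConvGnLelu(latent_filters[i], latent_filters[i + 1]))
        return nn.Sequential(*layers)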
@@ -144,9 +183,14 @@ class RRDBNetWithLatent(nn.Module):
             default_init_weights(m, 0.1)
 
     def forward(self, x, latent=None, ref=None):
+        latent_was_none = latent
         if latent is None:
             latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
         latent = checkpoint(self.latent_encoder, latent)
+        if latent_was_none is None:
+            self.latent_mean = torch.mean(latent).detach().cpu()
+            self.latent_std = torch.std(latent).detach().cpu()
+            self.latent_var = torch.var(latent).detach().cpu()
         if self.in_channels > 4:
             x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
             if ref is None:
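Note: despite its name, latent_was_none holds the caller-supplied latent (or None), not a boolean. The later `latent_was_none is None` test checks whether the latent was drawn randomly, since `latent` itself has been reassigned by then. An equivalent sketch with a more literal name (illustrative only; the self.* attributes match the commit):

    caller_supplied = latent is not None
    if not caller_supplied:
        latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
    latent = checkpoint(self.latent_encoder, latent)
    if not caller_supplied:
        # Log encoder output statistics only for randomly drawn latents.
        self.latent_mean = torch.mean(latent).detach().cpu()
        self.latent_std = torch.std(latent).detach().cpu()
        self.latent_var = torch.var(latent).detach().cpu()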
@@ -156,8 +200,12 @@ class RRDBNetWithLatent(nn.Module):
             x_lg = x
         feat = self.conv_first(x_lg)
         body_feat = feat
+        self.block_residual_means = []
+        self.block_residual_stds = []
         for bl in self.body:
-            body_feat = checkpoint(bl, body_feat, latent)
+            body_feat, residual = checkpoint(bl, body_feat, latent)
+            self.block_residual_means.append(torch.mean(residual).cpu())
+            self.block_residual_stds.append(torch.std(residual).cpu())
         body_feat = self.conv_body(body_feat)
         feat = feat + body_feat
         # upsample
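Note: the unpack `body_feat, residual = checkpoint(bl, body_feat, latent)` relies on checkpoint passing tuple outputs through unchanged, which torch.utils.checkpoint.checkpoint does (assuming utils.util.checkpoint wraps it). A self-contained sketch of that behavior:

    import torch
    from torch.utils.checkpoint import checkpoint

    # Toy stand-in for RRDBWithBypassAndLatent's new (output, residual) contract.
    def toy_block(x, latent):
        residual = x * 0.2  # latent is ignored in this toy
        return x + residual, residual

    x = torch.randn(2, 8, requires_grad=True)
    latent = torch.randn(2, 8)
    out, residual = checkpoint(toy_block, x, latent)  # tuple unpacks cleanly
    out.sum().backward()

Also worth noting: the per-block stats are stored with .cpu() but no .detach(), so they keep references into the autograd graph until the next forward pass overwrites them.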
@@ -175,6 +223,17 @@ class RRDBNetWithLatent(nn.Module):
         for i, bm in enumerate(self.body):
             torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
+    def get_debug_values(self, s, n):
+        blk_stds, blk_means = {}, {}
+        for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
+            blk_stds['block_%i' % (i+1,)] = s
+            blk_means['block_%i' % (i+1,)] = m
+        return {'encoded_latent_mean': self.latent_mean,
+                'encoded_latent_std': self.latent_std,
+                'encoded_latent_var': self.latent_var,
+                'blocks_mean': blk_means,
+                'blocks_std': blk_stds}
+
 
 # Based heavily on the same VGG arch used for the discriminator.
 class LatentEstimator(nn.Module):
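Note: in get_debug_values the loop variable `s` shadows the `s` parameter. Harmless here since the parameter is unused after the loop, but a rename reads better; a sketch with illustrative names (`step`/`net_name` are assumed meanings of `s`/`n`, not confirmed by this commit):

    def get_debug_values(self, step, net_name):
        blk_stds, blk_means = {}, {}
        for i, (std, mean) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
            blk_stds['block_%i' % (i + 1,)] = std
            blk_means['block_%i' % (i + 1,)] = mean
        return {'encoded_latent_mean': self.latent_mean,
                'encoded_latent_std': self.latent_std,
                'encoded_latent_var': self.latent_var,
                'blocks_mean': blk_means,
                'blocks_std': blk_stds}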
@@ -212,6 +271,7 @@ class LatentEstimator(nn.Module):
 
         self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
         self.linear2 = nn.Linear(latent_size*2, latent_size)
+        self.tanh = nn.Tanh()
 
     def compute_body(self, x):
         fea = self.lrelu(self.conv0_0(x))
@@ -236,5 +296,14 @@ class LatentEstimator(nn.Module):
         fea = checkpoint(self.compute_body, x)
         fea = fea.contiguous().view(fea.size(0), -1)
         fea = self.linear1(fea)
-        out = self.linear2(fea)
+        out = self.tanh(self.linear2(fea))
+        self.latent_mean = torch.mean(out)
+        self.latent_std = torch.std(out)
+        self.latent_var = torch.var(out)
         return out
+
+    def get_debug_values(self, s, n):
+        return {'latent_estimator_mean': self.latent_mean,
+                'latent_estimator_std': self.latent_std,
+                'latent_estimator_var': self.latent_var}
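Note: the new nn.Tanh() bounds the estimated latent to (-1, 1). Unlike RRDBNetWithLatent.forward() above, the logged statistics here are stored without .detach().cpu(), so they retain the autograd graph (and device memory) between steps. If they exist only to feed get_debug_values, detaching is the safer pattern; a sketch:

    out = self.tanh(self.linear2(fea))
    # Detach before caching for logging, mirroring RRDBNetWithLatent.forward().
    self.latent_mean = torch.mean(out).detach().cpu()
    self.latent_std = torch.std(out).detach().cpu()
    self.latent_var = torch.var(out).detach().cpu()
    return out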