More work on RRDB with latent
This commit is contained in:
parent
62d3b6496b
commit
4469d2e661
|
@ -46,13 +46,8 @@ class ResidualDenseBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class RRDBWithBypassAndLatent(nn.Module):
|
class RRDBWithBypassAndLatent(nn.Module):
|
||||||
def __init__(self, mid_channels, growth_channels=32, latent_dim=256):
|
def __init__(self, mid_channels, growth_channels=32):
|
||||||
super(RRDBWithBypassAndLatent, self).__init__()
|
super(RRDBWithBypassAndLatent, self).__init__()
|
||||||
self.latent_process = nn.Sequential(nn.Linear(latent_dim, latent_dim//2, bias=False),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_dim//2, mid_channels, bias=False),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(mid_channels, mid_channels, bias=True))
|
|
||||||
self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False),
|
self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False),
|
||||||
ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False))
|
ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False))
|
||||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
@ -63,12 +58,7 @@ class RRDBWithBypassAndLatent(nn.Module):
|
||||||
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||||
nn.Sigmoid())
|
nn.Sigmoid())
|
||||||
|
|
||||||
def forward(self, x, original_latent):
|
def forward(self, x, latent):
|
||||||
b, f, h, w = x.shape
|
|
||||||
latent = self.latent_process(original_latent)
|
|
||||||
b, l = latent.shape
|
|
||||||
latent = latent.view(b, l, 1, 1)
|
|
||||||
latent = latent.repeat(1, 1, h, w)
|
|
||||||
out = self.latent_join(torch.cat([x, latent], dim=1))
|
out = self.latent_join(torch.cat([x, latent], dim=1))
|
||||||
out = self.rdb1(out, x)
|
out = self.rdb1(out, x)
|
||||||
out = self.rdb2(out)
|
out = self.rdb2(out)
|
||||||
|
@ -79,6 +69,31 @@ class RRDBWithBypassAndLatent(nn.Module):
|
||||||
return residual + x, residual
|
return residual + x, residual
|
||||||
|
|
||||||
|
|
||||||
|
class ConvLatentEncoder(nn.Module):
|
||||||
|
def __init__(self, nf):
|
||||||
|
super(ConvLatentEncoder, self).__init__()
|
||||||
|
latent_filters = [nf * 4, nf * 2, nf]
|
||||||
|
layers = []
|
||||||
|
for i in range(len(latent_filters)-1):
|
||||||
|
layers.append(nn.Sequential(
|
||||||
|
ConvGnLelu(latent_filters[i], latent_filters[i], kernel_size=1, activation=True, bias=False, norm=True),
|
||||||
|
Interpolate(2),
|
||||||
|
ConvGnLelu(latent_filters[i], latent_filters[i+1], kernel_size=1, activation=True, bias=False, norm=True)))
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
ConvGnLelu(nf, nf, kernel_size=1, activation=True, bias=True, norm=True),
|
||||||
|
ConvGnLelu(nf, nf, kernel_size=1, activation=False, bias=True, norm=False))
|
||||||
|
self.layers = nn.ModuleList(layers)
|
||||||
|
|
||||||
|
def forward(self, latents):
|
||||||
|
assert len(latents) == 3
|
||||||
|
out = torch.zeros_like(latents[0])
|
||||||
|
for i in range(2):
|
||||||
|
out = out + latents[i]
|
||||||
|
out = self.layers[i](out)
|
||||||
|
out = out + latents[2]
|
||||||
|
return self.final(out)
|
||||||
|
|
||||||
|
|
||||||
class RRDBNetWithLatent(nn.Module):
|
class RRDBNetWithLatent(nn.Module):
|
||||||
# 8-layer MLP in the vein of StyleGAN.
|
# 8-layer MLP in the vein of StyleGAN.
|
||||||
def create_linear_latent_encoder(self, latent_size):
|
def create_linear_latent_encoder(self, latent_size):
|
||||||
|
@ -110,12 +125,7 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
# Creates a 2D latent by iterating through the provided latent_filters and doubling the
|
# Creates a 2D latent by iterating through the provided latent_filters and doubling the
|
||||||
# image size each step.
|
# image size each step.
|
||||||
def create_conv_latent_encoder(self, latent_filters):
|
def create_conv_latent_encoder(self, latent_filters):
|
||||||
layers = []
|
return ConvLatentEncoder(latent_filters)
|
||||||
for i in range(len(latent_filters)-1):
|
|
||||||
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
|
|
||||||
layers.extend(Interpolate(2))
|
|
||||||
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -124,14 +134,13 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
num_blocks=23,
|
num_blocks=23,
|
||||||
growth_channels=32,
|
growth_channels=32,
|
||||||
blocks_per_checkpoint=4,
|
blocks_per_checkpoint=4,
|
||||||
scale=4,
|
scale=4):
|
||||||
latent_size=256):
|
|
||||||
super(RRDBNetWithLatent, self).__init__()
|
super(RRDBNetWithLatent, self).__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.latent_size = latent_size
|
self.nf = mid_channels
|
||||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||||
|
@ -140,8 +149,7 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
RRDBWithBypassAndLatent,
|
RRDBWithBypassAndLatent,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
mid_channels=mid_channels,
|
mid_channels=mid_channels,
|
||||||
growth_channels=growth_channels,
|
growth_channels=growth_channels)
|
||||||
latent_dim=latent_size)
|
|
||||||
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||||
# upsample
|
# upsample
|
||||||
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||||
|
@ -150,31 +158,7 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
|
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
# 8-layer MLP in the vein of StyleGAN.
|
self.latent_encoder = self.create_conv_latent_encoder(mid_channels)
|
||||||
self.latent_encoder = nn.Sequential(nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Linear(latent_size, latent_size),
|
|
||||||
nn.BatchNorm1d(latent_size),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True))
|
|
||||||
|
|
||||||
for m in [
|
for m in [
|
||||||
self.conv_first, self.conv_body, self.conv_up1,
|
self.conv_first, self.conv_body, self.conv_up1,
|
||||||
|
@ -185,8 +169,10 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
def forward(self, x, latent=None, ref=None):
|
def forward(self, x, latent=None, ref=None):
|
||||||
latent_was_none = latent
|
latent_was_none = latent
|
||||||
if latent is None:
|
if latent is None:
|
||||||
latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
|
mults = [4, 2, 1]
|
||||||
latent = checkpoint(self.latent_encoder, latent)
|
b, f, h, w = x.shape
|
||||||
|
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
|
||||||
|
latent = self.latent_encoder(latent)
|
||||||
if latent_was_none is None:
|
if latent_was_none is None:
|
||||||
self.latent_mean = torch.mean(latent).detach().cpu()
|
self.latent_mean = torch.mean(latent).detach().cpu()
|
||||||
self.latent_std = torch.std(latent).detach().cpu()
|
self.latent_std = torch.std(latent).detach().cpu()
|
||||||
|
@ -238,7 +224,7 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
# Based heavily on the same VGG arch used for the discriminator.
|
# Based heavily on the same VGG arch used for the discriminator.
|
||||||
class LatentEstimator(nn.Module):
|
class LatentEstimator(nn.Module):
|
||||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||||
def __init__(self, in_nc, nf, latent_size=256):
|
def __init__(self, in_nc, nf):
|
||||||
super(LatentEstimator, self).__init__()
|
super(LatentEstimator, self).__init__()
|
||||||
# [64, 128, 128]
|
# [64, 128, 128]
|
||||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
@ -249,57 +235,50 @@ class LatentEstimator(nn.Module):
|
||||||
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||||
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
|
self.d1p1 = ConvGnLelu(nf * 2, nf, kernel_size=1, activation=True, norm=True, bias=True)
|
||||||
|
self.d1p2 = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
|
||||||
|
|
||||||
# [128, 32, 32]
|
# [128, 32, 32]
|
||||||
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||||
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||||
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
|
self.d2p1 = ConvGnLelu(nf * 4, nf * 2, kernel_size=1, activation=True, norm=True, bias=True)
|
||||||
|
self.d2p2 = ConvGnLelu(nf * 2, nf * 2, kernel_size=1, activation=False, norm=False, bias=True)
|
||||||
|
|
||||||
# [256, 16, 16]
|
# [256, 16, 16]
|
||||||
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||||
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
# [512, 8, 8]
|
self.d3p1 = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
|
||||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
self.d3p2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
|
||||||
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
|
||||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
|
||||||
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
|
||||||
final_nf = nf * 8
|
|
||||||
|
|
||||||
# activation function
|
self.lrelu = nn.LeakyReLU(.2, inplace=True)
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
|
|
||||||
self.linear2 = nn.Linear(latent_size*2, latent_size)
|
|
||||||
self.tanh = nn.Tanh()
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
def compute_body(self, x):
|
def compute_body(self, x):
|
||||||
fea = self.lrelu(self.conv0_0(x))
|
fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
|
||||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
|
||||||
|
|
||||||
#fea = torch.cat([fea, skip_med], dim=1)
|
|
||||||
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
|
|
||||||
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||||
|
o1 = self.tanh(self.d1p2(self.d1p1(fea)))
|
||||||
|
|
||||||
#fea = torch.cat([fea, skip_lo], dim=1)
|
|
||||||
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||||
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||||
|
o2 = self.tanh(self.d2p2(self.d2p1(fea)))
|
||||||
|
|
||||||
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||||
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||||
|
o3 = self.tanh(self.d3p2(self.d3p1(fea)))
|
||||||
|
|
||||||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
return o3, o2, o1
|
||||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
|
||||||
return fea
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
fea = checkpoint(self.compute_body, x)
|
fea = self.lrelu(self.conv0_0(x))
|
||||||
fea = fea.contiguous().view(fea.size(0), -1)
|
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||||
fea = self.linear1(fea)
|
out = list(checkpoint(self.compute_body, fea))
|
||||||
out = self.tanh(self.linear2(fea))
|
self.latent_mean = torch.mean(out[-1])
|
||||||
self.latent_mean = torch.mean(out)
|
self.latent_std = torch.std(out[-1])
|
||||||
self.latent_std = torch.std(out)
|
self.latent_var = torch.var(out[-1])
|
||||||
self.latent_var = torch.var(out)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_debug_values(self, s, n):
|
def get_debug_values(self, s, n):
|
||||||
|
|
|
@ -157,6 +157,8 @@ class Trainer:
|
||||||
print("Data fetch: %f" % (time() - _t))
|
print("Data fetch: %f" % (time() - _t))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
|
#self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
|
||||||
|
|
||||||
opt = self.opt
|
opt = self.opt
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
#### update learning rate
|
#### update learning rate
|
||||||
|
@ -278,7 +280,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user