More work on RRDB with latent

This commit is contained in:
James Betker 2020-11-05 22:13:05 -07:00
parent 62d3b6496b
commit 4469d2e661
2 changed files with 60 additions and 79 deletions

View File

@ -46,13 +46,8 @@ class ResidualDenseBlock(nn.Module):
class RRDBWithBypassAndLatent(nn.Module):
def __init__(self, mid_channels, growth_channels=32, latent_dim=256):
def __init__(self, mid_channels, growth_channels=32):
super(RRDBWithBypassAndLatent, self).__init__()
self.latent_process = nn.Sequential(nn.Linear(latent_dim, latent_dim//2, bias=False),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_dim//2, mid_channels, bias=False),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(mid_channels, mid_channels, bias=True))
self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False),
ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False))
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
@ -63,12 +58,7 @@ class RRDBWithBypassAndLatent(nn.Module):
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid())
def forward(self, x, original_latent):
b, f, h, w = x.shape
latent = self.latent_process(original_latent)
b, l = latent.shape
latent = latent.view(b, l, 1, 1)
latent = latent.repeat(1, 1, h, w)
def forward(self, x, latent):
out = self.latent_join(torch.cat([x, latent], dim=1))
out = self.rdb1(out, x)
out = self.rdb2(out)
@ -79,6 +69,31 @@ class RRDBWithBypassAndLatent(nn.Module):
return residual + x, residual
class ConvLatentEncoder(nn.Module):
    """Fuses three latent feature maps into a single ``nf``-channel map.

    The encoder is a pyramid of two stages followed by a final head. Each
    stage is ``ConvGnLelu -> Interpolate(2) -> ConvGnLelu`` using 1x1 convs,
    and maps ``nf*4 -> nf*2`` and ``nf*2 -> nf`` channels respectively.
    NOTE(review): ``Interpolate(2)`` is presumably a 2x spatial upsample; it is
    defined outside this block -- confirm.

    Args:
        nf: Channel count of the finest latent and of the output.
    """

    def __init__(self, nf):
        super(ConvLatentEncoder, self).__init__()
        # Channel widths from the coarsest latent to the finest one.
        widths = (nf * 4, nf * 2, nf)
        stages = [
            nn.Sequential(
                ConvGnLelu(c_in, c_in, kernel_size=1, activation=True, bias=False, norm=True),
                Interpolate(2),
                ConvGnLelu(c_in, c_out, kernel_size=1, activation=True, bias=False, norm=True),
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # `final` is registered before `layers`; keep this order so that
        # parameter/state_dict ordering stays the same.
        self.final = nn.Sequential(
            ConvGnLelu(nf, nf, kernel_size=1, activation=True, bias=True, norm=True),
            ConvGnLelu(nf, nf, kernel_size=1, activation=False, bias=True, norm=False),
        )
        self.layers = nn.ModuleList(stages)

    def forward(self, latents):
        """Merge the latents coarse-to-fine and run the final head.

        Args:
            latents: Sequence of exactly three tensors. Each one is added to
                the running result before the next stage, so it must match
                the running result's shape at that point.
                NOTE(review): presumably ordered as (nf*4, nf*2, nf) channels
                at increasing resolution -- verify against the caller.

        Returns:
            The output of ``self.final`` applied to the fused latent.
        """
        assert len(latents) == 3
        # Start from zeros so the first latent is merged by the same
        # add-then-process step as the second.
        fused = torch.zeros_like(latents[0])
        for stage, level in zip(self.layers, latents[:2]):
            fused = stage(fused + level)
        # The finest latent is added after the last stage.
        return self.final(fused + latents[2])
class RRDBNetWithLatent(nn.Module):
# 8-layer MLP in the vein of StyleGAN.
def create_linear_latent_encoder(self, latent_size):
@ -110,12 +125,7 @@ class RRDBNetWithLatent(nn.Module):
# Creates a 2D latent by iterating through the provided latent_filters and doubling the
# image size each step.
def create_conv_latent_encoder(self, latent_filters):
layers = []
for i in range(len(latent_filters)-1):
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
layers.extend(Interpolate(2))
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
return nn.Sequential(*layers)
return ConvLatentEncoder(latent_filters)
def __init__(self,
in_channels,
@ -124,14 +134,13 @@ class RRDBNetWithLatent(nn.Module):
num_blocks=23,
growth_channels=32,
blocks_per_checkpoint=4,
scale=4,
latent_size=256):
scale=4):
super(RRDBNetWithLatent, self).__init__()
self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale
self.in_channels = in_channels
self.latent_size = latent_size
self.nf = mid_channels
first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
first_conv_padding = 1 if first_conv_stride == 1 else 3
@ -140,8 +149,7 @@ class RRDBNetWithLatent(nn.Module):
RRDBWithBypassAndLatent,
num_blocks,
mid_channels=mid_channels,
growth_channels=growth_channels,
latent_dim=latent_size)
growth_channels=growth_channels)
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
@ -150,31 +158,7 @@ class RRDBNetWithLatent(nn.Module):
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# 8-layer MLP in the vein of StyleGAN.
self.latent_encoder = nn.Sequential(nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
self.latent_encoder = self.create_conv_latent_encoder(mid_channels)
for m in [
self.conv_first, self.conv_body, self.conv_up1,
@ -185,8 +169,10 @@ class RRDBNetWithLatent(nn.Module):
def forward(self, x, latent=None, ref=None):
latent_was_none = latent
if latent is None:
latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
latent = checkpoint(self.latent_encoder, latent)
mults = [4, 2, 1]
b, f, h, w = x.shape
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
latent = self.latent_encoder(latent)
if latent_was_none is None:
self.latent_mean = torch.mean(latent).detach().cpu()
self.latent_std = torch.std(latent).detach().cpu()
@ -238,7 +224,7 @@ class RRDBNetWithLatent(nn.Module):
# Based heavily on the same VGG arch used for the discriminator.
class LatentEstimator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, latent_size=256):
def __init__(self, in_nc, nf):
super(LatentEstimator, self).__init__()
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
@ -249,57 +235,50 @@ class LatentEstimator(nn.Module):
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
self.d1p1 = ConvGnLelu(nf * 2, nf, kernel_size=1, activation=True, norm=True, bias=True)
self.d1p2 = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
self.d2p1 = ConvGnLelu(nf * 4, nf * 2, kernel_size=1, activation=True, norm=True, bias=True)
self.d2p2 = ConvGnLelu(nf * 2, nf * 2, kernel_size=1, activation=False, norm=False, bias=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
final_nf = nf * 8
self.d3p1 = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
self.d3p2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
self.linear2 = nn.Linear(latent_size*2, latent_size)
self.lrelu = nn.LeakyReLU(.2, inplace=True)
self.tanh = nn.Tanh()
def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
#fea = torch.cat([fea, skip_med], dim=1)
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
o1 = self.tanh(self.d1p2(self.d1p1(fea)))
#fea = torch.cat([fea, skip_lo], dim=1)
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
o2 = self.tanh(self.d2p2(self.d2p1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
o3 = self.tanh(self.d3p2(self.d3p1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
return fea
return o3, o2, o1
def forward(self, x):
fea = checkpoint(self.compute_body, x)
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.linear1(fea)
out = self.tanh(self.linear2(fea))
self.latent_mean = torch.mean(out)
self.latent_std = torch.std(out)
self.latent_var = torch.var(out)
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
out = list(checkpoint(self.compute_body, fea))
self.latent_mean = torch.mean(out[-1])
self.latent_std = torch.std(out[-1])
self.latent_var = torch.var(out[-1])
return out
def get_debug_values(self, s, n):

View File

@ -157,6 +157,8 @@ class Trainer:
print("Data fetch: %f" % (time() - _t))
_t = time()
#self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
opt = self.opt
self.current_step += 1
#### update learning rate
@ -278,7 +280,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)