More work on RRDB with latent

This commit is contained in:
James Betker 2020-11-05 22:13:05 -07:00
parent 62d3b6496b
commit 4469d2e661
2 changed files with 60 additions and 79 deletions

View File

@ -46,13 +46,8 @@ class ResidualDenseBlock(nn.Module):
class RRDBWithBypassAndLatent(nn.Module): class RRDBWithBypassAndLatent(nn.Module):
def __init__(self, mid_channels, growth_channels=32, latent_dim=256): def __init__(self, mid_channels, growth_channels=32):
super(RRDBWithBypassAndLatent, self).__init__() super(RRDBWithBypassAndLatent, self).__init__()
self.latent_process = nn.Sequential(nn.Linear(latent_dim, latent_dim//2, bias=False),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_dim//2, mid_channels, bias=False),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(mid_channels, mid_channels, bias=True))
self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False), self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False),
ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False)) ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False))
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
@ -63,12 +58,7 @@ class RRDBWithBypassAndLatent(nn.Module):
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid()) nn.Sigmoid())
def forward(self, x, original_latent): def forward(self, x, latent):
b, f, h, w = x.shape
latent = self.latent_process(original_latent)
b, l = latent.shape
latent = latent.view(b, l, 1, 1)
latent = latent.repeat(1, 1, h, w)
out = self.latent_join(torch.cat([x, latent], dim=1)) out = self.latent_join(torch.cat([x, latent], dim=1))
out = self.rdb1(out, x) out = self.rdb1(out, x)
out = self.rdb2(out) out = self.rdb2(out)
@ -79,6 +69,31 @@ class RRDBWithBypassAndLatent(nn.Module):
return residual + x, residual return residual + x, residual
class ConvLatentEncoder(nn.Module):
def __init__(self, nf):
super(ConvLatentEncoder, self).__init__()
latent_filters = [nf * 4, nf * 2, nf]
layers = []
for i in range(len(latent_filters)-1):
layers.append(nn.Sequential(
ConvGnLelu(latent_filters[i], latent_filters[i], kernel_size=1, activation=True, bias=False, norm=True),
Interpolate(2),
ConvGnLelu(latent_filters[i], latent_filters[i+1], kernel_size=1, activation=True, bias=False, norm=True)))
self.final = nn.Sequential(
ConvGnLelu(nf, nf, kernel_size=1, activation=True, bias=True, norm=True),
ConvGnLelu(nf, nf, kernel_size=1, activation=False, bias=True, norm=False))
self.layers = nn.ModuleList(layers)
def forward(self, latents):
assert len(latents) == 3
out = torch.zeros_like(latents[0])
for i in range(2):
out = out + latents[i]
out = self.layers[i](out)
out = out + latents[2]
return self.final(out)
class RRDBNetWithLatent(nn.Module): class RRDBNetWithLatent(nn.Module):
# 8-layer MLP in the vein of StyleGAN. # 8-layer MLP in the vein of StyleGAN.
def create_linear_latent_encoder(self, latent_size): def create_linear_latent_encoder(self, latent_size):
@ -110,12 +125,7 @@ class RRDBNetWithLatent(nn.Module):
# Creates a 2D latent by iterating through the provided latent_filters and doubling the # Creates a 2D latent by iterating through the provided latent_filters and doubling the
# image size each step. # image size each step.
def create_conv_latent_encoder(self, latent_filters): def create_conv_latent_encoder(self, latent_filters):
layers = [] return ConvLatentEncoder(latent_filters)
for i in range(len(latent_filters)-1):
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
layers.extend(Interpolate(2))
layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
return nn.Sequential(*layers)
def __init__(self, def __init__(self,
in_channels, in_channels,
@ -124,14 +134,13 @@ class RRDBNetWithLatent(nn.Module):
num_blocks=23, num_blocks=23,
growth_channels=32, growth_channels=32,
blocks_per_checkpoint=4, blocks_per_checkpoint=4,
scale=4, scale=4):
latent_size=256):
super(RRDBNetWithLatent, self).__init__() super(RRDBNetWithLatent, self).__init__()
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale self.scale = scale
self.in_channels = in_channels self.in_channels = in_channels
self.latent_size = latent_size self.nf = mid_channels
first_conv_stride = 1 if in_channels <= 4 else scale first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7 first_conv_ksize = 3 if first_conv_stride == 1 else 7
first_conv_padding = 1 if first_conv_stride == 1 else 3 first_conv_padding = 1 if first_conv_stride == 1 else 3
@ -140,8 +149,7 @@ class RRDBNetWithLatent(nn.Module):
RRDBWithBypassAndLatent, RRDBWithBypassAndLatent,
num_blocks, num_blocks,
mid_channels=mid_channels, mid_channels=mid_channels,
growth_channels=growth_channels, growth_channels=growth_channels)
latent_dim=latent_size)
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
# upsample # upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
@ -150,31 +158,7 @@ class RRDBNetWithLatent(nn.Module):
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# 8-layer MLP in the vein of StyleGAN. self.latent_encoder = self.create_conv_latent_encoder(mid_channels)
self.latent_encoder = nn.Sequential(nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(latent_size, latent_size),
nn.BatchNorm1d(latent_size),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
for m in [ for m in [
self.conv_first, self.conv_body, self.conv_up1, self.conv_first, self.conv_body, self.conv_up1,
@ -185,8 +169,10 @@ class RRDBNetWithLatent(nn.Module):
def forward(self, x, latent=None, ref=None): def forward(self, x, latent=None, ref=None):
latent_was_none = latent latent_was_none = latent
if latent is None: if latent is None:
latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device) mults = [4, 2, 1]
latent = checkpoint(self.latent_encoder, latent) b, f, h, w = x.shape
latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
latent = self.latent_encoder(latent)
if latent_was_none is None: if latent_was_none is None:
self.latent_mean = torch.mean(latent).detach().cpu() self.latent_mean = torch.mean(latent).detach().cpu()
self.latent_std = torch.std(latent).detach().cpu() self.latent_std = torch.std(latent).detach().cpu()
@ -238,7 +224,7 @@ class RRDBNetWithLatent(nn.Module):
# Based heavily on the same VGG arch used for the discriminator. # Based heavily on the same VGG arch used for the discriminator.
class LatentEstimator(nn.Module): class LatentEstimator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, latent_size=256): def __init__(self, in_nc, nf):
super(LatentEstimator, self).__init__() super(LatentEstimator, self).__init__()
# [64, 128, 128] # [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
@ -249,57 +235,50 @@ class LatentEstimator(nn.Module):
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
self.d1p1 = ConvGnLelu(nf * 2, nf, kernel_size=1, activation=True, norm=True, bias=True)
self.d1p2 = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
# [128, 32, 32] # [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
self.d2p1 = ConvGnLelu(nf * 4, nf * 2, kernel_size=1, activation=True, norm=True, bias=True)
self.d2p2 = ConvGnLelu(nf * 2, nf * 2, kernel_size=1, activation=False, norm=False, bias=True)
# [256, 16, 16] # [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8] self.d3p1 = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) self.d3p2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
final_nf = nf * 8
# activation function self.lrelu = nn.LeakyReLU(.2, inplace=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
self.linear2 = nn.Linear(latent_size*2, latent_size)
self.tanh = nn.Tanh() self.tanh = nn.Tanh()
def compute_body(self, x): def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
#fea = torch.cat([fea, skip_med], dim=1)
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
o1 = self.tanh(self.d1p2(self.d1p1(fea)))
#fea = torch.cat([fea, skip_lo], dim=1)
fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
o2 = self.tanh(self.d2p2(self.d2p1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
o3 = self.tanh(self.d3p2(self.d3p1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) return o3, o2, o1
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
return fea
def forward(self, x): def forward(self, x):
fea = checkpoint(self.compute_body, x) fea = self.lrelu(self.conv0_0(x))
fea = fea.contiguous().view(fea.size(0), -1) fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
fea = self.linear1(fea) out = list(checkpoint(self.compute_body, fea))
out = self.tanh(self.linear2(fea)) self.latent_mean = torch.mean(out[-1])
self.latent_mean = torch.mean(out) self.latent_std = torch.std(out[-1])
self.latent_std = torch.std(out) self.latent_var = torch.var(out[-1])
self.latent_var = torch.var(out)
return out return out
def get_debug_values(self, s, n): def get_debug_values(self, s, n):

View File

@ -157,6 +157,8 @@ class Trainer:
print("Data fetch: %f" % (time() - _t)) print("Data fetch: %f" % (time() - _t))
_t = time() _t = time()
#self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
opt = self.opt opt = self.opt
self.current_step += 1 self.current_step += 1
#### update learning rate #### update learning rate
@ -278,7 +280,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args() args = parser.parse_args()
opt = option.parse(args.opt, is_train=True) opt = option.parse(args.opt, is_train=True)