diff --git a/codes/models/archs/rrdb_with_latent.py b/codes/models/archs/rrdb_with_latent.py
index 952d4271..8951dc10 100644
--- a/codes/models/archs/rrdb_with_latent.py
+++ b/codes/models/archs/rrdb_with_latent.py
@@ -46,13 +46,8 @@ class ResidualDenseBlock(nn.Module):
 
 
 class RRDBWithBypassAndLatent(nn.Module):
-    def __init__(self, mid_channels, growth_channels=32, latent_dim=256):
+    def __init__(self, mid_channels, growth_channels=32):
         super(RRDBWithBypassAndLatent, self).__init__()
-        self.latent_process = nn.Sequential(nn.Linear(latent_dim, latent_dim//2, bias=False),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_dim//2, mid_channels, bias=False),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(mid_channels, mid_channels, bias=True))
         self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False),
                                          ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False))
         self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
@@ -63,12 +58,7 @@ class RRDBWithBypassAndLatent(nn.Module):
                                     ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
                                     nn.Sigmoid())
 
-    def forward(self, x, original_latent):
-        b, f, h, w = x.shape
-        latent = self.latent_process(original_latent)
-        b, l = latent.shape
-        latent = latent.view(b, l, 1, 1)
-        latent = latent.repeat(1, 1, h, w)
+    def forward(self, x, latent):
         out = self.latent_join(torch.cat([x, latent], dim=1))
         out = self.rdb1(out, x)
         out = self.rdb2(out)
@@ -79,6 +69,31 @@ class RRDBWithBypassAndLatent(nn.Module):
         return residual + x, residual
 
 
+class ConvLatentEncoder(nn.Module):
+    def __init__(self, nf):
+        super(ConvLatentEncoder, self).__init__()
+        latent_filters = [nf * 4, nf * 2, nf]
+        layers = []
+        for i in range(len(latent_filters)-1):
+            layers.append(nn.Sequential(
+                ConvGnLelu(latent_filters[i], latent_filters[i], kernel_size=1, activation=True, bias=False, norm=True),
+                Interpolate(2),
+                ConvGnLelu(latent_filters[i], latent_filters[i+1], kernel_size=1, activation=True, bias=False, norm=True)))
+        self.final = nn.Sequential(
+            ConvGnLelu(nf, nf, kernel_size=1, activation=True, bias=True, norm=True),
+            ConvGnLelu(nf, nf, kernel_size=1, activation=False, bias=True, norm=False))
+        self.layers = nn.ModuleList(layers)
+
+    def forward(self, latents):
+        assert len(latents) == 3
+        out = torch.zeros_like(latents[0])
+        for i in range(2):
+            out = out + latents[i]
+            out = self.layers[i](out)
+        out = out + latents[2]
+        return self.final(out)
+
+
 class RRDBNetWithLatent(nn.Module):
     # 8-layer MLP in the vein of StyleGAN.
     def create_linear_latent_encoder(self, latent_size):
@@ -110,12 +125,7 @@ class RRDBNetWithLatent(nn.Module):
     # Creates a 2D latent by iterating through the provided latent_filters and doubling the
     # image size each step.
     def create_conv_latent_encoder(self, latent_filters):
-        layers = []
-        for i in range(len(latent_filters)-1):
-            layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i]))
-            layers.extend(Interpolate(2))
-            layers.extend(ConvGnLelu(latent_filters[i], latent_filters[i+1]))
-        return nn.Sequential(*layers)
+        return ConvLatentEncoder(latent_filters)
 
     def __init__(self,
                  in_channels,
@@ -124,14 +134,13 @@ class RRDBNetWithLatent(nn.Module):
                  num_blocks=23,
                  growth_channels=32,
                  blocks_per_checkpoint=4,
-                 scale=4,
-                 latent_size=256):
+                 scale=4):
         super(RRDBNetWithLatent, self).__init__()
         self.num_blocks = num_blocks
         self.blocks_per_checkpoint = blocks_per_checkpoint
         self.scale = scale
         self.in_channels = in_channels
-        self.latent_size = latent_size
+        self.nf = mid_channels
         first_conv_stride = 1 if in_channels <= 4 else scale
         first_conv_ksize = 3 if first_conv_stride == 1 else 7
         first_conv_padding = 1 if first_conv_stride == 1 else 3
@@ -140,8 +149,7 @@ class RRDBNetWithLatent(nn.Module):
             RRDBWithBypassAndLatent,
             num_blocks,
             mid_channels=mid_channels,
-            growth_channels=growth_channels,
-            latent_dim=latent_size)
+            growth_channels=growth_channels)
         self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
         # upsample
         self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
@@ -150,31 +158,7 @@ class RRDBNetWithLatent(nn.Module):
         self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-        # 8-layer MLP in the vein of StyleGAN.
-        self.latent_encoder = nn.Sequential(nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                            nn.Linear(latent_size, latent_size),
-                                            nn.BatchNorm1d(latent_size),
-                                            nn.LeakyReLU(negative_slope=0.2, inplace=True))
+        self.latent_encoder = self.create_conv_latent_encoder(mid_channels)
 
         for m in [
             self.conv_first, self.conv_body, self.conv_up1,
@@ -185,8 +169,10 @@ class RRDBNetWithLatent(nn.Module):
     def forward(self, x, latent=None, ref=None):
         latent_was_none = latent
         if latent is None:
-            latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device)
-            latent = checkpoint(self.latent_encoder, latent)
+            mults = [4, 2, 1]
+            b, f, h, w = x.shape
+            latent = [torch.randn((b, self.nf * m, h // m, w // m), dtype=torch.float, device=x.device) for m in mults]
+            latent = self.latent_encoder(latent)
         if latent_was_none is None:
             self.latent_mean = torch.mean(latent).detach().cpu()
             self.latent_std = torch.std(latent).detach().cpu()
@@ -238,7 +224,7 @@ class RRDBNetWithLatent(nn.Module):
 # Based heavily on the same VGG arch used for the discriminator.
 class LatentEstimator(nn.Module):
     # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
-    def __init__(self, in_nc, nf, latent_size=256):
+    def __init__(self, in_nc, nf):
         super(LatentEstimator, self).__init__()
         # [64, 128, 128]
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
@@ -249,57 +235,50 @@ class LatentEstimator(nn.Module):
         self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
         self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.d1p1 = ConvGnLelu(nf * 2, nf, kernel_size=1, activation=True, norm=True, bias=True)
+        self.d1p2 = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
+
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
         self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
         self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
         self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.d2p1 = ConvGnLelu(nf * 4, nf * 2, kernel_size=1, activation=True, norm=True, bias=True)
+        self.d2p2 = ConvGnLelu(nf * 2, nf * 2, kernel_size=1, activation=False, norm=False, bias=True)
+
         # [256, 16, 16]
         self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
         self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
         self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
         self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
-        # [512, 8, 8]
-        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
-        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
-        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
-        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
-        final_nf = nf * 8
+        self.d3p1 = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, activation=True, norm=True, bias=True)
+        self.d3p2 = ConvGnLelu(nf * 4, nf * 4, kernel_size=1, activation=False, norm=False, bias=True)
 
-        # activation function
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-        self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2)
-        self.linear2 = nn.Linear(latent_size*2, latent_size)
+        self.lrelu = nn.LeakyReLU(.2, inplace=True)
         self.tanh = nn.Tanh()
 
     def compute_body(self, x):
-        fea = self.lrelu(self.conv0_0(x))
-        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
-
-        #fea = torch.cat([fea, skip_med], dim=1)
-        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
+        fea = self.lrelu(self.bn1_0(self.conv1_0(x)))
         fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
+        o1 = self.tanh(self.d1p2(self.d1p1(fea)))
 
-        #fea = torch.cat([fea, skip_lo], dim=1)
         fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
         fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
+        o2 = self.tanh(self.d2p2(self.d2p1(fea)))
 
         fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
         fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
+        o3 = self.tanh(self.d3p2(self.d3p1(fea)))
 
-        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
-        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
-        return fea
+        return o3, o2, o1
 
     def forward(self, x):
-        fea = checkpoint(self.compute_body, x)
-        fea = fea.contiguous().view(fea.size(0), -1)
-        fea = self.linear1(fea)
-        out = self.tanh(self.linear2(fea))
-        self.latent_mean = torch.mean(out)
-        self.latent_std = torch.std(out)
-        self.latent_var = torch.var(out)
+        fea = self.lrelu(self.conv0_0(x))
+        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
+        out = list(checkpoint(self.compute_body, fea))
+        self.latent_mean = torch.mean(out[-1])
+        self.latent_std = torch.std(out[-1])
+        self.latent_var = torch.var(out[-1])
         return out
 
     def get_debug_values(self, s, n):
diff --git a/codes/train2.py b/codes/train2.py
index d068c125..cd608f97 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -157,6 +157,8 @@ class Trainer:
                 print("Data fetch: %f" % (time() - _t))
                 _t = time()
 
+            #self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
+
             opt = self.opt
             self.current_step += 1
             #### update learning rate
@@ -278,7 +280,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)