diff --git a/codes/models/archs/rrdb_with_latent.py b/codes/models/archs/rrdb_with_latent.py new file mode 100644 index 00000000..7e988129 --- /dev/null +++ b/codes/models/archs/rrdb_with_latent.py @@ -0,0 +1,240 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from torch.utils.checkpoint import checkpoint_sequential + +from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu +from utils.util import checkpoint + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + mid_channels (int): Channel number of intermediate features. + growth_channels (int): Channels for each growth. + """ + + def __init__(self, mid_channels=64, growth_channels=32): + super(ResidualDenseBlock, self).__init__() + for i in range(5): + out_channels = mid_channels if i == 4 else growth_channels + self.add_module( + f'conv{i+1}', + nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3, + 1, 1)) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + for i in range(5): + default_init_weights(getattr(self, f'conv{i+1}'), 0.1) + + + def forward(self, x, identity=None): + if identity is None: + identity = x + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + identity + + +class RRDBWithBypassAndLatent(nn.Module): + def __init__(self, mid_channels, growth_channels=32, latent_dim=256): + super(RRDBWithBypassAndLatent, self).__init__() + self.latent_process = nn.Sequential(nn.Linear(latent_dim, latent_dim//2, bias=False), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_dim//2, mid_channels, bias=False), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(mid_channels, mid_channels, bias=True)) + self.latent_join = nn.Sequential(ConvGnLelu(mid_channels*2, mid_channels*2, activation=True, norm=False, bias=False), + ConvGnLelu(mid_channels*2, mid_channels, activation=False, norm=False, bias=False)) + self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) + self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) + self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), + ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), + ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), + nn.Sigmoid()) + + def forward(self, x, original_latent): + b, f, h, w = x.shape + latent = self.latent_process(original_latent) + b, l = latent.shape + latent = latent.view(b, l, 1, 1) + latent = latent.repeat(1, 1, h, w) + out = self.latent_join(torch.cat([x, latent], dim=1)) + out = self.rdb1(out, x) + out = self.rdb2(out) + out = self.rdb3(out) + bypass = self.bypass(torch.cat([x, out], dim=1)) + self.bypass_map = bypass.detach().clone() + return out * 0.2 * bypass + x + + +class RRDBNetWithLatent(nn.Module): + def __init__(self, + in_channels, + out_channels, + mid_channels=64, + num_blocks=23, + growth_channels=32, + blocks_per_checkpoint=4, + scale=4, + latent_size=256): + super(RRDBNetWithLatent, self).__init__() + self.num_blocks = num_blocks + self.blocks_per_checkpoint = blocks_per_checkpoint + self.scale = scale + self.in_channels = in_channels + self.latent_size = latent_size + first_conv_stride = 1 if in_channels <= 4 else scale + first_conv_ksize = 3 if first_conv_stride == 1 else 7 + first_conv_padding = 1 if first_conv_stride == 1 else 3 + self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) + self.body = make_layer( + RRDBWithBypassAndLatent, + num_blocks, + mid_channels=mid_channels, + growth_channels=growth_channels, + latent_dim=latent_size) + self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # 8-layer MLP in the vein of StyleGAN. + self.latent_encoder = nn.Sequential(nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Linear(latent_size, latent_size), + nn.BatchNorm1d(latent_size), + nn.LeakyReLU(negative_slope=0.2, inplace=True)) + + for m in [ + self.conv_first, self.conv_body, self.conv_up1, + self.conv_up2, self.conv_hr, self.conv_last + ]: + default_init_weights(m, 0.1) + + def forward(self, x, latent=None, ref=None): + if latent is None: + latent = torch.randn((x.shape[0], self.latent_size), dtype=torch.float, device=x.device) + latent = checkpoint(self.latent_encoder, latent) + if self.in_channels > 4: + x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") + if ref is None: + ref = torch.zeros_like(x_lg) + x_lg = torch.cat([x_lg, ref], dim=1) + else: + x_lg = x + feat = self.conv_first(x_lg) + body_feat = feat + for bl in self.body: + body_feat = checkpoint(bl, body_feat, latent) + body_feat = self.conv_body(body_feat) + feat = feat + body_feat + # upsample + feat = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + if self.scale == 4: + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + else: + feat = self.lrelu(self.conv_up2(feat)) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out + + def visual_dbg(self, step, path): + for i, bm in enumerate(self.body): + torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) + + +# Based heavily on the same VGG arch used for the discriminator. +class LatentEstimator(nn.Module): + # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. + def __init__(self, in_nc, nf, latent_size=256): + super(LatentEstimator, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + final_nf = nf * 8 + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.linear1 = nn.Linear(int(final_nf * 4 * 4), latent_size*2) + self.linear2 = nn.Linear(latent_size*2, latent_size) + + def compute_body(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + #fea = torch.cat([fea, skip_med], dim=1) + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + #fea = torch.cat([fea, skip_lo], dim=1) + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + return fea + + def forward(self, x): + fea = checkpoint(self.compute_body, x) + fea = fea.contiguous().view(fea.size(0), -1) + fea = self.linear1(fea) + out = self.linear2(fea) + return out \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index dc103b74..64f84d8e 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -20,6 +20,7 @@ import models.archs.rcan as rcan import models.archs.ChainedEmbeddingGen as chained from models.archs import srg2_classic from models.archs.pyramid_arch import BasicResamplingFlowNet +from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent from models.archs.teco_resgen import TecoGen logger = logging.getLogger('base') @@ -118,6 +119,12 @@ def define_G(opt, net_key='network_G', scale=None): netG = TecoGen(opt_net['nf'], opt_net['scale']) elif which_model == "basic_resampling_flow_predictor": netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale']) + elif which_model == "rrdb_with_latent": + netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], + blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale']) + elif which_model == "latent_estimator": + netG = LatentEstimator(in_nc=3, nf=opt_net['nf']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 3767b89b..1eb02939 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -26,7 +26,7 @@ class ConfigurableStep(Module): self.optimizers = None self.scaler = GradScaler(enabled=self.opt['fp16']) self.grads_generated = False - self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0 + self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999 self.injectors = [] if 'injectors' in self.step_opt.keys():