diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index e0c99338..1a0a52e6 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -7,6 +7,7 @@ import torchvision
 from torch.utils.checkpoint import checkpoint_sequential
 
 from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from utils.util import checkpoint
 
 
 class ResidualDenseBlock(nn.Module):
@@ -280,3 +281,110 @@ class RRDBNet(nn.Module):
                 torchvision.utils.save_image(bm.bypass_map.cpu().float(),
                                              os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
+
+class DiscRDB(nn.Module):
+    """Residual dense block used by the RRDB discriminator.
+
+    Five ConvGnLelu layers are densely connected: each one receives the block
+    input concatenated with the outputs of all earlier layers. The fifth layer
+    maps back to ``mid_channels`` so its output can be added to the input.
+    """
+
+    def __init__(self, mid_channels=64, growth_channels=32):
+        super(DiscRDB, self).__init__()
+        for i in range(5):
+            out_channels = mid_channels if i == 4 else growth_channels
+            # The last conv feeds the residual sum in forward(), which applies its
+            # own LeakyReLU, so it is built without norm/activation. The previous
+            # test (i != 5) could never be False inside range(5).
+            actnorm = i != 4
+            self.add_module(
+                f'conv{i+1}',
+                ConvGnLelu(mid_channels + i * growth_channels, out_channels, kernel_size=3, norm=actnorm, activation=actnorm, bias=True)
+            )
+        self.lrelu = nn.LeakyReLU(negative_slope=.2)
+        for i in range(5):
+            default_init_weights(getattr(self, f'conv{i+1}'), 1)
+
+    def forward(self, x):
+        """Return ``lrelu(conv5(cat(x, x1, x2, x3, x4)) + x)``."""
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return self.lrelu(x5 + x)
+
+
+class DiscRRDB(nn.Module):
+    """Three chained DiscRDB blocks with a skip connection and GroupNorm."""
+
+    def __init__(self, mid_channels, growth_channels=32):
+        super(DiscRRDB, self).__init__()
+        self.rdb1 = DiscRDB(mid_channels, growth_channels)
+        self.rdb2 = DiscRDB(mid_channels, growth_channels)
+        self.rdb3 = DiscRDB(mid_channels, growth_channels)
+        self.gn = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
+
+    def forward(self, x):
+        """Run the three blocks, add the input back, then normalize the sum."""
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        return self.gn(out + x)
+
+
+class RRDBDiscriminator(nn.Module):
+    """Discriminator built from a stack of DiscRRDB blocks.
+
+    Args:
+        in_channels: Channel count expected by the first conv.
+        mid_channels: Feature width used by the body blocks.
+        num_blocks: Number of DiscRRDB blocks requested from make_layer.
+        growth_channels: Growth width inside each DiscRDB.
+        blocks_per_checkpoint: Body blocks grouped into one
+            checkpoint_sequential segment.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 mid_channels=64,
+                 num_blocks=23,
+                 growth_channels=32,
+                 blocks_per_checkpoint=1
+                 ):
+        super(RRDBDiscriminator, self).__init__()
+        self.num_blocks = num_blocks
+        self.blocks_per_checkpoint = blocks_per_checkpoint
+        self.in_channels = in_channels
+        self.conv_first = ConvGnLelu(in_channels, mid_channels, 3, stride=4, activation=False, norm=False, bias=True)
+        self.body = make_layer(
+            DiscRRDB,
+            num_blocks,
+            mid_channels=mid_channels,
+            growth_channels=growth_channels)
+        self.tail = nn.Sequential(
+            ConvGnLelu(mid_channels, mid_channels // 2, kernel_size=1, activation=True, norm=False, bias=True),
+            ConvGnLelu(mid_channels // 2, mid_channels // 4, kernel_size=1, activation=True, norm=False, bias=True),
+            ConvGnLelu(mid_channels // 4, 1, kernel_size=1, activation=False, norm=False, bias=True)
+        )
+        self.pred_ = None
+
+    def forward(self, x):
+        """Score ``x``; also caches a detached copy of the result for visual_dbg."""
+        feat = self.conv_first(x)
+        # checkpoint_sequential divides by the segment count, so clamp it to at
+        # least 1 for the case where blocks_per_checkpoint exceeds num_blocks.
+        segments = max(1, self.num_blocks // self.blocks_per_checkpoint)
+        feat = checkpoint_sequential(self.body, segments, feat)
+        pred = checkpoint(self.tail, feat)
+        self.pred_ = pred.detach().clone()
+        return pred
+
+    def visual_dbg(self, step, path):
+        """Save the sigmoid of the cached prediction as ``"%i_predictions.png" % step``."""
+        if self.pred_ is not None:
+            # Use a local so repeated calls do not re-apply the sigmoid to the
+            # cached tensor; torch.sigmoid replaces the deprecated F.sigmoid.
+            probs = torch.sigmoid(self.pred_)
+            torchvision.utils.save_image(probs.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
diff --git a/codes/models/archs/lambda_rrdb.py b/codes/models/archs/lambda_rrdb.py
index 9493a6b9..19ed468e 100644
--- a/codes/models/archs/lambda_rrdb.py
+++ b/codes/models/archs/lambda_rrdb.py
@@ -4,6 +4,7 @@ from lambda_networks import LambdaLayer
 from torch.nn import GroupNorm
 
 from models.archs.RRDBNet_arch import ResidualDenseBlock
+from models.archs.arch_util import ConvGnLelu
 
 
 class LambdaRRDB(nn.Module):
@@ -18,13 +19,16 @@ class LambdaRRDB(nn.Module):
 
     def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
         super(LambdaRRDB, self).__init__()
-        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
-        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
         if reduce_to is None:
             reduce_to = mid_channels
-        self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
-        self.gn = GroupNorm(num_groups=8, num_channels=mid_channels)
-        self.scale = nn.Parameter(torch.full((1,), 1/256))
+        self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
+        self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels)
+        self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
+        self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels)
+        self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
+        # lam3 emits reduce_to channels, so its norm must be sized to match.
+        self.gn3 = GroupNorm(num_groups=8, num_channels=reduce_to)
+        self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1)
 
     def forward(self, x):
         """Forward function.
@@ -35,8 +39,12 @@ class LambdaRRDB(nn.Module):
         Returns:
             Tensor: Forward results.
         """
-        out = self.rdb1(x)
-        out = self.rdb2(out)
-        out = self.lam(out)
-        out = self.gn(out)
-        return out * self.scale + x
\ No newline at end of file
+        out = self.lam1(x)
+        out = self.gn1(out)
+        out = self.lam2(out)
+        out = self.gn2(out)
+        out = self.lam3(out)
+        out = self.gn3(out)
+        # NOTE(review): the residual add needs conv's reduce_to channels to be
+        # compatible with x's mid_channels - confirm how reduce_to is used.
+        return self.conv(out) * .2 + x
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index e4a76f73..d7f0f1f5 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -39,10 +39,10 @@ def define_G(opt, opt_net, scale=None):
                                     nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
     elif 'RRDBNet' in which_model:
         if which_model == 'RRDBNetBypass':
+            block = RRDBNet_arch.RRDBWithBypass
+        elif which_model == 'RRDBNetLambda':
             from models.archs.lambda_rrdb import LambdaRRDB
             block = LambdaRRDB
-        elif which_model == 'RRDBNetLambda':
-            block = RRDBNet_arch.RRDBWithBypass
         else:
             block = RRDBNet_arch.RRDB
         additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
@@ -226,6 +226,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
     elif which_model == "stylegan2_unet":
         disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
         netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
+    elif which_model == "rrdb_disc":
+        netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD