RRDB disc work

James Betker 2020-11-27 12:03:08 -07:00
parent 6de4dabb73
commit 4ab49b0d69
3 changed files with 96 additions and 12 deletions

View File: models/archs/RRDBNet_arch.py

@ -7,6 +7,7 @@ import torchvision
from torch.utils.checkpoint import checkpoint_sequential
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from utils.util import checkpoint
class ResidualDenseBlock(nn.Module):
@ -280,3 +281,79 @@ class RRDBNet(nn.Module):
            torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))

class DiscRDB(nn.Module):
    def __init__(self, mid_channels=64, growth_channels=32):
        super(DiscRDB, self).__init__()
        for i in range(5):
            # Convs 1-4 each emit growth_channels features; the final conv projects
            # back to mid_channels so the residual add in forward() is valid.
            out_channels = mid_channels if i == 4 else growth_channels
            # Only the final conv is left without norm/activation; the residual sum
            # is activated once by self.lrelu in forward().
            actnorm = i != 4
            self.add_module(
                f'conv{i+1}',
                ConvGnLelu(mid_channels + i * growth_channels, out_channels, kernel_size=3, norm=actnorm, activation=actnorm, bias=True)
            )
        self.lrelu = nn.LeakyReLU(negative_slope=.2)
        for i in range(5):
            default_init_weights(getattr(self, f'conv{i+1}'), 1)
    def forward(self, x):
        # Dense connectivity: each conv sees the block input plus all earlier outputs.
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return self.lrelu(x5 + x)

class DiscRRDB(nn.Module):
    def __init__(self, mid_channels, growth_channels=32):
        super(DiscRRDB, self).__init__()
        self.rdb1 = DiscRDB(mid_channels, growth_channels)
        self.rdb2 = DiscRDB(mid_channels, growth_channels)
        self.rdb3 = DiscRDB(mid_channels, growth_channels)
        self.gn = nn.GroupNorm(num_groups=8, num_channels=mid_channels)

    def forward(self, x):
        # Three dense blocks in series; the residual sum is group-normalized.
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return self.gn(out + x)

class RRDBDiscriminator(nn.Module):
    def __init__(self,
                 in_channels,
                 mid_channels=64,
                 num_blocks=23,
                 growth_channels=32,
                 blocks_per_checkpoint=1
                 ):
        super(RRDBDiscriminator, self).__init__()
        self.num_blocks = num_blocks
        self.blocks_per_checkpoint = blocks_per_checkpoint
        self.in_channels = in_channels
        # Stride-4 stem: the discriminator works on a 4x-downsampled feature map.
        self.conv_first = ConvGnLelu(in_channels, mid_channels, 3, stride=4, activation=False, norm=False, bias=True)
        self.body = make_layer(
            DiscRRDB,
            num_blocks,
            mid_channels=mid_channels,
            growth_channels=growth_channels)
        # 1x1 convs collapse the features to a single-channel logit map.
        self.tail = nn.Sequential(
            ConvGnLelu(mid_channels, mid_channels // 2, kernel_size=1, activation=True, norm=False, bias=True),
            ConvGnLelu(mid_channels // 2, mid_channels // 4, kernel_size=1, activation=True, norm=False, bias=True),
            ConvGnLelu(mid_channels // 4, 1, kernel_size=1, activation=False, norm=False, bias=True)
        )
        self.pred_ = None

    def forward(self, x):
        feat = self.conv_first(x)
        # Gradient checkpointing over the RRDB body trades compute for memory.
        feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
        pred = checkpoint(self.tail, feat)
        # Keep a detached copy of the last prediction for visual_dbg().
        self.pred_ = pred.detach().clone()
        return pred

    def visual_dbg(self, step, path):
        if self.pred_ is not None:
            self.pred_ = torch.sigmoid(self.pred_)
            torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
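
A note on the dense wiring in DiscRDB: each convN receives the block input concatenated with every earlier output, so its input width is mid_channels + (N-1) * growth_channels, and conv5 maps back to mid_channels so the residual add works. A minimal plain-PyTorch sketch of that channel bookkeeping (illustrative sizes, plain nn.Conv2d standing in for the repo's ConvGnLelu):

import torch
import torch.nn as nn

mid, growth = 64, 32  # the DiscRDB defaults
convs = nn.ModuleList([
    nn.Conv2d(mid + i * growth, mid if i == 4 else growth, kernel_size=3, padding=1)
    for i in range(5)
])
x = torch.randn(1, mid, 16, 16)
feats = [x]
for conv in convs:
    feats.append(conv(torch.cat(feats, dim=1)))  # input width grows by `growth` each step
print(feats[-1].shape)  # torch.Size([1, 64, 16, 16]) -- same shape as x, so x5 + x is valid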

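For intuition about the new discriminator's output shape: conv_first downsamples by 4 and the 1x1 tail keeps that resolution, so RRDBDiscriminator emits a single-channel logit map (one value per 4x4 input patch) rather than a scalar. A rough standalone sketch of just that geometry (hypothetical sizes; plain convs instead of ConvGnLelu/make_layer, with the RRDB body and checkpointing omitted):

import torch
import torch.nn as nn

stem = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)   # stand-in for conv_first
tail = nn.Sequential(                                          # stand-in for the 1x1 tail
    nn.Conv2d(64, 32, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(32, 16, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(16, 1, 1),
)
img = torch.randn(1, 3, 128, 128)                              # assumed 3-channel 128x128 input
logits = tail(stem(img))                                       # the DiscRRDB body would sit between these
print(logits.shape)                                            # torch.Size([1, 1, 32, 32])
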
View File: models/archs/lambda_rrdb.py

@ -4,6 +4,7 @@ from lambda_networks import LambdaLayer
from torch.nn import GroupNorm
from models.archs.RRDBNet_arch import ResidualDenseBlock
from models.archs.arch_util import ConvGnLelu
class LambdaRRDB(nn.Module):
@ -18,13 +19,15 @@ class LambdaRRDB(nn.Module):
    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
        super(LambdaRRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        if reduce_to is None:
            reduce_to = mid_channels
        self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.scale = nn.Parameter(torch.full((1,), 1/256))
        self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn3 = GroupNorm(num_groups=8, num_channels=reduce_to)
        self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1)

    def forward(self, x):
        """Forward function.
@ -35,8 +38,10 @@ class LambdaRRDB(nn.Module):
        Returns:
            Tensor: Forward results.
        """
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.lam(out)
        out = self.gn(out)
        return out * self.scale + x
        out = self.lam1(x)
        out = self.gn1(out)
        out = self.lam2(out)
        out = self.gn2(out)
        out = self.lam3(out)
        out = self.gn3(out)
        return self.conv(out) * .2 + x
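
The revised LambdaRRDB thus swaps the dense blocks for a stack of three LambdaLayer + GroupNorm stages, finishing with a lightly-initialized 1x1 conv and a 0.2-scaled residual. A quick shape check for one such stage, reusing the constructor arguments shown above (assumes the lambda_networks package imported at the top of this file; sizes are illustrative):

import torch
from torch.nn import GroupNorm
from lambda_networks import LambdaLayer

mid_channels = 64  # illustrative
lam = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
gn = GroupNorm(num_groups=8, num_channels=mid_channels)

x = torch.randn(1, mid_channels, 32, 32)
out = gn(lam(x))            # one of the three stages in the new forward()
print(out.shape)            # torch.Size([1, 64, 32, 32]); same shape as x
y = out * 0.2 + x           # the scaled-residual pattern the block ends with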

View File

@ -39,10 +39,10 @@ def define_G(opt, opt_net, scale=None):
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif 'RRDBNet' in which_model:
        if which_model == 'RRDBNetBypass':
            block = RRDBNet_arch.RRDBWithBypass
        elif which_model == 'RRDBNetLambda':
            from models.archs.lambda_rrdb import LambdaRRDB
            block = LambdaRRDB
        elif which_model == 'RRDBNetLambda':
            block = RRDBNet_arch.RRDBWithBypass
        else:
            block = RRDBNet_arch.RRDB
        additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
@ -226,6 +226,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
elif which_model == "stylegan2_unet":
disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
elif which_model == "rrdb_disc":
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
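
For reference, wiring the new modules into a run comes down to the option keys read above. A hedged illustration in Python (only in_nc/nf/nb, scale, additive_mode and the which_model comparison strings appear in this diff; the exact keys that carry which_model are not shown here, so the surrounding dict structure is an assumption):

# Hypothetical option dicts -- key names other than those read in the hunks above are guesses.
disc_opt = {
    'which_model': 'rrdb_disc',  # routes define_D_net() to RRDBNet_arch.RRDBDiscriminator
    'in_nc': 3,                  # -> in_channels
    'nf': 64,                    # -> mid_channels
    'nb': 23,                    # -> num_blocks (blocks_per_checkpoint is hard-coded to 3)
}
gen_opt = {
    'which_model_G': 'RRDBNetLambda',  # selects LambdaRRDB as the block type in define_G()
    'nf': 64,
    'nb': 23,
    'scale': 4,
    'additive_mode': 'not',            # the default used when the key is absent
}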