RRDB disc work
This commit is contained in:
parent
6de4dabb73
commit
4ab49b0d69
|
@ -7,6 +7,7 @@ import torchvision
|
|||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
|
||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
|
@ -280,3 +281,79 @@ class RRDBNet(nn.Module):
|
|||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
|
||||
|
||||
class DiscRDB(nn.Module):
|
||||
def __init__(self, mid_channels=64, growth_channels=32):
|
||||
super(DiscRDB, self).__init__()
|
||||
for i in range(5):
|
||||
out_channels = mid_channels if i == 4 else growth_channels
|
||||
actnorm = i != 5
|
||||
self.add_module(
|
||||
f'conv{i+1}',
|
||||
ConvGnLelu(mid_channels + i * growth_channels, out_channels, kernel_size=3, norm=actnorm, activation=actnorm, bias=True)
|
||||
)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=.2)
|
||||
for i in range(5):
|
||||
default_init_weights(getattr(self, f'conv{i+1}'), 1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return self.lrelu(x5 + x)
|
||||
|
||||
|
||||
class DiscRRDB(nn.Module):
|
||||
def __init__(self, mid_channels, growth_channels=32):
|
||||
super(DiscRRDB, self).__init__()
|
||||
self.rdb1 = DiscRDB(mid_channels, growth_channels)
|
||||
self.rdb2 = DiscRDB(mid_channels, growth_channels)
|
||||
self.rdb3 = DiscRDB(mid_channels, growth_channels)
|
||||
self.gn = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
return self.gn(out + x)
|
||||
|
||||
|
||||
class RRDBDiscriminator(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
mid_channels=64,
|
||||
num_blocks=23,
|
||||
growth_channels=32,
|
||||
blocks_per_checkpoint=1
|
||||
):
|
||||
super(RRDBDiscriminator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.in_channels = in_channels
|
||||
self.conv_first = ConvGnLelu(in_channels, mid_channels, 3, stride=4, activation=False, norm=False, bias=True)
|
||||
self.body = make_layer(
|
||||
DiscRRDB,
|
||||
num_blocks,
|
||||
mid_channels=mid_channels,
|
||||
growth_channels=growth_channels)
|
||||
self.tail = nn.Sequential(
|
||||
ConvGnLelu(mid_channels, mid_channels // 2, kernel_size=1, activation=True, norm=False, bias=True),
|
||||
ConvGnLelu(mid_channels // 2, mid_channels // 4, kernel_size=1, activation=True, norm=False, bias=True),
|
||||
ConvGnLelu(mid_channels // 4, 1, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
)
|
||||
self.pred_ = None
|
||||
|
||||
def forward(self, x):
|
||||
feat = self.conv_first(x)
|
||||
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
pred = checkpoint(self.tail, feat)
|
||||
self.pred_ = pred.detach().clone()
|
||||
return pred
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
if self.pred_ is not None:
|
||||
self.pred_ = F.sigmoid(self.pred_)
|
||||
torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
|
||||
|
|
|
@ -4,6 +4,7 @@ from lambda_networks import LambdaLayer
|
|||
from torch.nn import GroupNorm
|
||||
|
||||
from models.archs.RRDBNet_arch import ResidualDenseBlock
|
||||
from models.archs.arch_util import ConvGnLelu
|
||||
|
||||
|
||||
class LambdaRRDB(nn.Module):
|
||||
|
@ -18,13 +19,15 @@ class LambdaRRDB(nn.Module):
|
|||
|
||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
||||
super(LambdaRRDB, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
|
||||
if reduce_to is None:
|
||||
reduce_to = mid_channels
|
||||
self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
|
||||
self.gn = GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.scale = nn.Parameter(torch.full((1,), 1/256))
|
||||
self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
|
||||
self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
|
||||
self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
|
||||
self.gn3 = GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
|
@ -35,8 +38,10 @@ class LambdaRRDB(nn.Module):
|
|||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.lam(out)
|
||||
out = self.gn(out)
|
||||
return out * self.scale + x
|
||||
out = self.lam1(x)
|
||||
out = self.gn1(out)
|
||||
out = self.lam2(out)
|
||||
out = self.gn1(out)
|
||||
out = self.lam3(out)
|
||||
out = self.gn3(out)
|
||||
return self.conv(out) * .2 + x
|
|
@ -39,10 +39,10 @@ def define_G(opt, opt_net, scale=None):
|
|||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif 'RRDBNet' in which_model:
|
||||
if which_model == 'RRDBNetBypass':
|
||||
block = RRDBNet_arch.RRDBWithBypass
|
||||
elif which_model == 'RRDBNetLambda':
|
||||
from models.archs.lambda_rrdb import LambdaRRDB
|
||||
block = LambdaRRDB
|
||||
elif which_model == 'RRDBNetLambda':
|
||||
block = RRDBNet_arch.RRDBWithBypass
|
||||
else:
|
||||
block = RRDBNet_arch.RRDB
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
|
@ -226,6 +226,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
|||
elif which_model == "stylegan2_unet":
|
||||
disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
|
||||
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||
elif which_model == "rrdb_disc":
|
||||
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
Loading…
Reference in New Issue
Block a user