Add serveral new RRDB architectures

This commit is contained in:
James Betker 2020-06-09 13:28:55 -06:00
parent 296135ec18
commit 12e8fad079
2 changed files with 174 additions and 37 deletions

View File

@ -61,8 +61,6 @@ class RRDB(nn.Module):
return out * 0.2 + x
class AttentiveRRDB(RRDB):
counter = 0
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
super(RRDB, self).__init__()
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
@ -72,8 +70,6 @@ class AttentiveRRDB(RRDB):
self.final_temperature_step = final_temperature_step
self.running_mean = 0
self.running_count = 0
self.counter = AttentiveRRDB.counter
AttentiveRRDB.counter += 1
def set_temperature(self, temp):
self.RDB1.switcher.set_attention_temperature(temp)
@ -93,37 +89,28 @@ class AttentiveRRDB(RRDB):
return out * 0.2 + x
def get_debug_values(self, step):
def get_debug_values(self, step, prefix):
# Take the chance to update the temperature here.
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
self.set_temperature(temp)
# Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
val = {"%s_attention_mean" % (prefix,): self.running_mean / self.running_count,
"attention_temperature": temp}
self.running_count = 0
self.running_mean = 0
return val
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
rrdb_block_f=None):
super(RRDBNet, self).__init__()
# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
class RRDBTrunk(nn.Module):
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
super(RRDBTrunk, self).__init__()
if rrdb_block_f is None:
rrdb_block_f = functools.partial(RRDB, nf=nf, gc=gc)
rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
# Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
# convolutions.
@ -135,6 +122,58 @@ class RRDBNet(nn.Module):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
return fea
def get_debug_values(self, step, prefix):
val = {}
i = 0
for block in self.RRDB_trunk._modules.values():
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i)))
i += 1
return val
# Adds some base methods that all RRDB* classes will use.
class RRDBBase(nn.Module):
def __init__(self):
super(RRDBBase, self).__init__()
# Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
# convolutions.
def set_temperature(self, temp):
for trunk in self.trunks:
for layer in trunk.rrdb_layers:
layer.set_temperature(temp)
def get_debug_values(self, step):
val = {}
for i, trunk in enumerate(self.trunks):
for j, block in enumerate(trunk.RRDB_trunk._modules.values()):
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
return val
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
class RRDBNet(RRDBBase):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
rrdb_block_f=None):
super(RRDBNet, self).__init__()
# Trunk - does actual processing.
self.trunk = RRDBTrunk(in_nc, nf, nb, gc, initial_stride, rrdb_block_f)
self.trunks = [self.trunk]
# Upsampling
self.scale = scale
self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.trunk(x)
if self.scale >= 2:
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
@ -146,15 +185,16 @@ class RRDBNet(nn.Module):
return (out,)
def get_debug_values(self, step):
val = {}
for block in self.RRDB_trunk._modules.values():
if hasattr(block, "get_debug_values"):
val.update(block.get_debug_values(step))
return val
def load_state_dict(self, state_dict, strict=True):
# The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.
t_state = self.trunk.state_dict()
for k in t_state.keys():
state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
super(RRDBNet, self).load_state_dict(state_dict, strict)
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
# intermediate layers have been splayed out, pixel-shuffled, and fed back in.
# TODO: Convert to use new RRDBBase hierarchy.
class AssistedRRDBNet(nn.Module):
# in_nc=number of input channels.
# out_nc=number of output channels.
@ -171,10 +211,9 @@ class AssistedRRDBNet(nn.Module):
# Set-up the assist-net, which should do feature extraction for us.
self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)
self.set_enable_assistnet_training(False)
assist_nf = [2, 4, 8, 16] # Fixed for resnet. Re-evaluate if using other networks.
self.assist1 = RRDB(nf + assist_nf[0], gc)
self.assist2 = RRDB(nf + sum(assist_nf[:2]), gc)
self.assist3 = RRDB(nf + sum(assist_nf[:3]), gc)
assist_nf = [4, 8, 16] # Fixed for resnet. Re-evaluate if using other networks.
self.assist2 = RRDB(nf + assist_nf[0], gc)
self.assist3 = RRDB(nf + sum(assist_nf[:2]), gc)
self.assist4 = RRDB(nf + sum(assist_nf), gc)
nf = nf + sum(assist_nf)
@ -195,6 +234,11 @@ class AssistedRRDBNet(nn.Module):
p.requires_grad = en
def res_extract(self, x):
# Width and height must be factors of 16 to use this architecture. Check that here.
(b, f, w, h) = x.shape
assert w % 16 == 0
assert h % 16 == 0
x = self.assistnet.conv1(x)
x = self.assistnet.bn1(x)
x = self.assistnet.relu(x)
@ -206,16 +250,13 @@ class AssistedRRDBNet(nn.Module):
l2 = F.pixel_shuffle(x, 8)
x = self.assistnet.layer3(x)
l3 = F.pixel_shuffle(x, 16)
x = self.assistnet.layer4(x)
l4 = F.pixel_shuffle(x, 32)
return l1, l2, l3, l4
return l1, l2, l3
def forward(self, x):
# Invoke the assistant net first.
l1, l2, l3, l4 = self.res_extract(x)
l1, l2, l3 = self.res_extract(x)
fea = self.conv_first(x)
fea = self.assist1(torch.cat([fea, l4], dim=1))
fea = self.assist2(torch.cat([fea, l3], dim=1))
fea = self.assist3(torch.cat([fea, l2], dim=1))
fea = self.assist4(torch.cat([fea, l1], dim=1))
@ -231,4 +272,85 @@ class AssistedRRDBNet(nn.Module):
fea = self.lrelu(self.upconv2(fea))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return (out,)
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
class PixShuffleRRDB(RRDBBase):
def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
super(PixShuffleRRDB, self).__init__()
# This class does a 4x pixel shuffle on the filter count inside the trunk, so nf must be divisible by 16.
assert nf % 16 == 0
# Trunk - does actual processing.
self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
self.trunks = [self.trunk]
# Upsampling
pix_nf = int(nf/16)
self.scale = scale
self.upconv1 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(pix_nf, 3, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(4)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.trunk(x)
fea = self.pixel_shuffle(fea)
if self.scale >= 2:
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv1(fea))
if self.scale >= 4:
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv2(fea))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return (out,)
# This class uses two RRDB trunks to process an image at different resolution levels.
class MultiRRDBNet(RRDBBase):
def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
super(MultiRRDBNet, self).__init__()
# Initial downsampling.
self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
# Chained trunks
lo_nf = nf_base * 4
hi_nf = nf_base
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
self.trunks = [self.lo_trunk, self.hi_trunk]
# Upsampling
self.scale = scale
self.upconv1 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
fea_lo = self.lo_trunk(fea)
fea = self.pixel_shuffle(fea_lo) + fea
fea = self.hi_trunk(fea)
# First, return image to original size and perform post-processing.
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv1(fea))
# If 2x scaling is specified, do that too.
if self.scale >= 2:
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv2(fea))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return (out,)

View File

@ -38,6 +38,21 @@ def define_G(opt, net_key='network_G'):
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step']))
elif which_model == 'MultiRRDBNet':
block_f = None
if opt_net['attention']:
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f)
elif which_model == 'PixRRDBNet':
block_f = None
if opt_net['attention']:
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])