Add several new RRDB architectures
commit 12e8fad079
parent 296135ec18
@@ -61,8 +61,6 @@ class RRDB(nn.Module):
         return out * 0.2 + x
 
 class AttentiveRRDB(RRDB):
-    counter = 0
-
     def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
         super(RRDB, self).__init__()
         self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
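Note the `super(RRDB, self).__init__()` call above: it deliberately skips `RRDB.__init__` (which would build the standard dense blocks) and jumps straight to `nn.Module.__init__`, so the subclass can install `SwitchedRDB_5C` blocks instead. A minimal sketch of that idiom, with hypothetical class names:

```python
# Minimal sketch (hypothetical names) of the super(RRDB, self).__init__() idiom.
class Base:                      # plays the role of nn.Module
    def __init__(self):
        print("Base.__init__")

class Mid(Base):                 # plays the role of RRDB
    def __init__(self):
        super().__init__()
        print("Mid.__init__ builds the default blocks")

class Leaf(Mid):                 # plays the role of AttentiveRRDB
    def __init__(self):
        # super(Mid, self) starts the MRO lookup *after* Mid, so
        # Mid.__init__ is skipped and Leaf builds its own blocks.
        super(Mid, self).__init__()
        print("Leaf.__init__ builds switched blocks instead")

Leaf()  # prints Base.__init__, then Leaf.__init__ ...
```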
@@ -72,8 +70,6 @@ class AttentiveRRDB(RRDB):
         self.final_temperature_step = final_temperature_step
         self.running_mean = 0
         self.running_count = 0
-        self.counter = AttentiveRRDB.counter
-        AttentiveRRDB.counter += 1
 
     def set_temperature(self, temp):
         self.RDB1.switcher.set_attention_temperature(temp)
@@ -93,37 +89,28 @@ class AttentiveRRDB(RRDB):
 
         return out * 0.2 + x
 
-    def get_debug_values(self, step):
+    def get_debug_values(self, step, prefix):
         # Take the chance to update the temperature here.
         temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
         self.set_temperature(temp)
 
         # Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
-        val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
+        val = {"%s_attention_mean" % (prefix,): self.running_mean / self.running_count,
                "attention_temperature": temp}
         self.running_count = 0
         self.running_mean = 0
         return val
 
+
-class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
-                 rrdb_block_f=None):
-        super(RRDBNet, self).__init__()
+# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
+class RRDBTrunk(nn.Module):
+    def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
+        super(RRDBTrunk, self).__init__()
         if rrdb_block_f is None:
-            rrdb_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+            rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
 
-        self.scale = scale
-        self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
+        self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
         self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-
-        #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
 
-    # Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
-    # convolutions.
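The schedule inside `get_debug_values` anneals the softmax temperature linearly from `init_temperature` down to 1 over `final_temperature_step` steps, then clamps at 1. A quick sketch of the decay, using illustrative values:

```python
# Illustrative values only: init_temperature=10 annealed over 10000 steps.
init_temperature = 10
final_temperature_step = 10000

def temp_at(step):
    # Same expression as in AttentiveRRDB.get_debug_values above.
    return max(1, int(init_temperature * (final_temperature_step - step) / final_temperature_step))

print([temp_at(s) for s in (0, 2500, 5000, 7500, 10000, 20000)])
# -> [10, 7, 5, 2, 1, 1]
```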
@@ -135,6 +122,58 @@ class RRDBNet(nn.Module):
         fea = self.conv_first(x)
         trunk = self.trunk_conv(self.RRDB_trunk(fea))
         fea = fea + trunk
+        return fea
+
+    def get_debug_values(self, step, prefix):
+        val = {}
+        i = 0
+        for block in self.RRDB_trunk._modules.values():
+            if hasattr(block, "get_debug_values"):
+                val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i)))
+            i += 1
+        return val
+
+
+# Adds some base methods that all RRDB* classes will use.
+class RRDBBase(nn.Module):
+    def __init__(self):
+        super(RRDBBase, self).__init__()
+
+    # Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
+    # convolutions.
+    def set_temperature(self, temp):
+        for trunk in self.trunks:
+            for layer in trunk.rrdb_layers:
+                layer.set_temperature(temp)
+
+    def get_debug_values(self, step):
+        val = {}
+        for i, trunk in enumerate(self.trunks):
+            for j, block in enumerate(trunk.RRDB_trunk._modules.values()):
+                if hasattr(block, "get_debug_values"):
+                    val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
+        return val
+
+
+# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
+class RRDBNet(RRDBBase):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
+                 rrdb_block_f=None):
+        super(RRDBNet, self).__init__()
+
+        # Trunk - does actual processing.
+        self.trunk = RRDBTrunk(in_nc, nf, nb, gc, initial_stride, rrdb_block_f)
+        self.trunks = [self.trunk]
+
+        # Upsampling
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.trunk(x)
 
         if self.scale >= 2:
             fea = F.interpolate(fea, scale_factor=2, mode='nearest')
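With the trunk split out, every RRDB* network just populates `self.trunks` and inherits temperature control and debug reporting from `RRDBBase`. A hypothetical usage sketch, assuming the constructor signatures above (module path assumed; the attention stats require at least one forward pass, since `running_count` starts at 0):

```python
import functools
from RRDBNet_arch import RRDBNet, AttentiveRRDB  # module path assumed

block_f = functools.partial(AttentiveRRDB, nf=64, gc=32,
                            init_temperature=10, final_temperature_step=10000)
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=2, rrdb_block_f=block_f)

net.set_temperature(5)               # fans out to every layer in every trunk
# ... run a forward pass first so running_mean/running_count are populated ...
stats = net.get_debug_values(step=100)
# keys look like "trunk_0_block_3_attention_mean", plus "attention_temperature"
```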
@@ -146,15 +185,16 @@ class RRDBNet(nn.Module):
 
         return (out,)
 
-    def get_debug_values(self, step):
-        val = {}
-        for block in self.RRDB_trunk._modules.values():
-            if hasattr(block, "get_debug_values"):
-                val.update(block.get_debug_values(step))
-        return val
+    def load_state_dict(self, state_dict, strict=True):
+        # The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.
+        t_state = self.trunk.state_dict()
+        for k in t_state.keys():
+            state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
+        super(RRDBNet, self).load_state_dict(state_dict, strict)
 
 # Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
 # intermediate layers have been splayed out, pixel-shuffled, and fed back in.
+# TODO: Convert to use new RRDBBase hierarchy.
 class AssistedRRDBNet(nn.Module):
     # in_nc=number of input channels.
     # out_nc=number of output channels.
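The `load_state_dict` override keeps old checkpoints loadable: parameters that used to live directly on `RRDBNet` are re-keyed under the new `trunk.` prefix before the normal load runs. A sketch of the renaming effect, with hypothetical checkpoint keys:

```python
# Sketch of the key migration above (hypothetical keys, dummy values).
legacy = {
    "conv_first.weight": "...",               # used to live on RRDBNet directly
    "RRDB_trunk.0.RDB1.conv1.weight": "...",
    "upconv1.weight": "...",                  # upsampling stayed on RRDBNet: no rename
}
trunk_keys = ["conv_first.weight", "RRDB_trunk.0.RDB1.conv1.weight"]  # self.trunk.state_dict().keys()
for k in trunk_keys:
    legacy["trunk.%s" % (k,)] = legacy.pop(k)

print(sorted(legacy))
# ['trunk.RRDB_trunk.0.RDB1.conv1.weight', 'trunk.conv_first.weight', 'upconv1.weight']
```

Note that the `pop` has no default, so this path assumes every incoming checkpoint still uses the legacy key names.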
@@ -171,10 +211,9 @@ class AssistedRRDBNet(nn.Module):
         # Set-up the assist-net, which should do feature extraction for us.
         self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)
         self.set_enable_assistnet_training(False)
-        assist_nf = [2, 4, 8, 16]  # Fixed for resnet. Re-evaluate if using other networks.
-        self.assist1 = RRDB(nf + assist_nf[0], gc)
-        self.assist2 = RRDB(nf + sum(assist_nf[:2]), gc)
-        self.assist3 = RRDB(nf + sum(assist_nf[:3]), gc)
+        assist_nf = [4, 8, 16]  # Fixed for resnet. Re-evaluate if using other networks.
+        self.assist2 = RRDB(nf + assist_nf[0], gc)
+        self.assist3 = RRDB(nf + sum(assist_nf[:2]), gc)
         self.assist4 = RRDB(nf + sum(assist_nf), gc)
         nf = nf + sum(assist_nf)
 
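The `assist_nf` constants fall out of pixel-shuffle channel math: `F.pixel_shuffle(x, r)` turns `(b, c, h, w)` into `(b, c/r², h·r, w·r)`. A sketch of the arithmetic, assuming the standard torchvision wide_resnet50_2 stage widths (layer1's shuffle factor of 4 is not visible in these hunks, but the `[4, 8, 16]` list implies it):

```python
# Assumed torchvision wide_resnet50_2 stage widths: layer1..layer4 emit
# 256, 512, 1024, 2048 channels respectively.
stages = {"layer1": (256, 4), "layer2": (512, 8), "layer3": (1024, 16), "layer4": (2048, 32)}
for name, (c, r) in stages.items():
    print(name, c // (r * r))    # channels left after F.pixel_shuffle(x, r)
# layer1 16, layer2 8, layer3 4, layer4 2 -> dropping layer4 leaves [4, 8, 16]
```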
@@ -195,6 +234,11 @@ class AssistedRRDBNet(nn.Module):
             p.requires_grad = en
 
     def res_extract(self, x):
+        # Width and height must be multiples of 16 to use this architecture. Check that here.
+        (b, f, w, h) = x.shape
+        assert w % 16 == 0
+        assert h % 16 == 0
+
         x = self.assistnet.conv1(x)
         x = self.assistnet.bn1(x)
         x = self.assistnet.relu(x)
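The new shape assertion exists because the deepest assist features now come from layer3, which downsamples by 16x; `F.pixel_shuffle(..., 16)` only restores the original resolution exactly when width and height are multiples of 16. (PyTorch tensors are NCHW, so the unpacked names `w`/`h` are really height/width, but since both dimensions are checked it is harmless.) A sketch of the failure mode it guards against, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# layer3-like features for a 64x48 input (64/16=4, 48/16=3): shuffle restores the size exactly.
ok = F.pixel_shuffle(torch.zeros(1, 1024, 4, 3), 16)
print(ok.shape)  # torch.Size([1, 4, 64, 48])

# For a 60x48 input, layer3 would emit floor(60/16)=3 rows; the shuffled map comes back
# as 48x48, so torch.cat with the full-resolution 60x48 feature map would fail.
```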
@@ -206,16 +250,13 @@ class AssistedRRDBNet(nn.Module):
         l2 = F.pixel_shuffle(x, 8)
         x = self.assistnet.layer3(x)
         l3 = F.pixel_shuffle(x, 16)
-        x = self.assistnet.layer4(x)
-        l4 = F.pixel_shuffle(x, 32)
-        return l1, l2, l3, l4
+        return l1, l2, l3
 
     def forward(self, x):
         # Invoke the assistant net first.
-        l1, l2, l3, l4 = self.res_extract(x)
+        l1, l2, l3 = self.res_extract(x)
 
         fea = self.conv_first(x)
-        fea = self.assist1(torch.cat([fea, l4], dim=1))
         fea = self.assist2(torch.cat([fea, l3], dim=1))
         fea = self.assist3(torch.cat([fea, l2], dim=1))
         fea = self.assist4(torch.cat([fea, l1], dim=1))
@@ -231,4 +272,85 @@ class AssistedRRDBNet(nn.Module):
         fea = self.lrelu(self.upconv2(fea))
         out = self.conv_last(self.lrelu(self.HRconv(fea)))
 
         return (out,)
+
+
+# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
+class PixShuffleRRDB(RRDBBase):
+    def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
+        super(PixShuffleRRDB, self).__init__()
+
+        # This class does a 4x pixel shuffle on the filter count inside the trunk, so nf must be divisible by 16.
+        assert nf % 16 == 0
+
+        # Trunk - does actual processing.
+        self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
+        self.trunks = [self.trunk]
+
+        # Upsampling
+        pix_nf = int(nf/16)
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(pix_nf, 3, 3, 1, 1, bias=True)
+        self.pixel_shuffle = nn.PixelShuffle(4)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.trunk(x)
+        fea = self.pixel_shuffle(fea)
+
+        if self.scale >= 2:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+            fea = self.lrelu(self.upconv1(fea))
+        if self.scale >= 4:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+            fea = self.lrelu(self.upconv2(fea))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return (out,)
+
+
+# This class uses two RRDB trunks to process an image at different resolution levels.
+class MultiRRDBNet(RRDBBase):
+    def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
+        super(MultiRRDBNet, self).__init__()
+
+        # Initial downsampling.
+        self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
+
+        # Chained trunks
+        lo_nf = nf_base * 4
+        hi_nf = nf_base
+        self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
+        self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
+        self.trunks = [self.lo_trunk, self.hi_trunk]
+
+        # Upsampling
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        fea_lo = self.lo_trunk(fea)
+        fea = self.pixel_shuffle(fea_lo) + fea
+        fea = self.hi_trunk(fea)
+
+        # First, return image to original size and perform post-processing.
+        fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+        fea = self.lrelu(self.upconv1(fea))
+
+        # If 2x scaling is specified, do that too.
+        if self.scale >= 2:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+            fea = self.lrelu(self.upconv2(fea))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return (out,)
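Both new heads lean on the same trick: `nn.PixelShuffle(r)` trades r² channels for an r-fold spatial upsample, so the channel counts have to be chosen to line up. A sketch of the bookkeeping with illustrative values:

```python
# Channel bookkeeping behind the two new heads (illustrative values).
nf = 64                          # PixShuffleRRDB
assert nf % 16 == 0
pix_nf = nf // 16                # 4: channels left after nn.PixelShuffle(4) divides by 4**2;
                                 # the shuffle also undoes the trunk's initial_stride=4 downsample

nf_base = 32                     # MultiRRDBNet
lo_nf = nf_base * 4              # 128 channels out of lo_trunk, at 1/4 of the input resolution
after_shuffle = lo_nf // 2**2    # nn.PixelShuffle(2) divides channels by 4
assert after_shuffle == nf_base  # why `self.pixel_shuffle(fea_lo) + fea` lines up exactly
```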
@@ -38,6 +38,21 @@ def define_G(opt, net_key='network_G'):
                                        rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
                                                                       init_temperature=opt_net['temperature'],
                                                                       final_temperature_step=opt_net['temperature_final_step']))
+    elif which_model == 'MultiRRDBNet':
+        block_f = None
+        if opt_net['attention']:
+            block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
+                                        init_temperature=opt_net['temperature'],
+                                        final_temperature_step=opt_net['temperature_final_step'])
+        netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
+                                         hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f)
+    elif which_model == 'PixRRDBNet':
+        block_f = None
+        if opt_net['attention']:
+            block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
+                                        init_temperature=opt_net['temperature'],
+                                        final_temperature_step=opt_net['temperature_final_step'])
+        netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
     elif which_model == 'ResGen':
         netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
                                           upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
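For reference, the options the new branches expect; a hypothetical `opt_net` fragment, with key names taken from the lookups above, illustrative values, and `which_model_G` assumed to be the selector this `define_G` reads:

```python
# Hypothetical opt_net fragment for the new branches (values illustrative).
opt_net = {
    'which_model_G': 'MultiRRDBNet',   # or 'PixRRDBNet'
    'nf': 64, 'gc': 32,
    'lo_blocks': 10, 'hi_blocks': 5,   # MultiRRDBNet only; PixRRDBNet uses 'nb' instead
    'attention': True,                 # swap plain RRDB blocks for AttentiveRRDB
    'temperature': 10,                 # read only when attention is enabled
    'temperature_final_step': 10000,
}
```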