Add skip heads to switcher

These pass through the input so that it can be selected by the attention mechanism.
This commit is contained in:
James Betker 2020-06-14 12:46:54 -06:00
parent 6c27ddc9b5
commit be7982b9ae
2 changed files with 12 additions and 6 deletions

View File

@ -37,7 +37,7 @@ class ResidualDenseBlock_5C(nn.Module):
# If collapse_heads=True, outputs (b,f,w,h) tensor.
# If collapse_heads=False, outputs (b,heads,f,w,h) tensor.
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
if force_block is None:
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
else:
@ -49,6 +49,7 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
num_heads,
att_kernel_size=3,
att_pads=1,
include_skip_head=include_skip_head,
initial_temperature=init_temperature,
multi_head_input=multi_head_input,
concat_heads_into_filters=collapse_heads,
@ -100,13 +101,14 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
# It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension
# of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there.
class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C):
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, collapse_heads=False):
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, collapse_heads=False):
rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads)
super(SwitchedRDB_5C_MultiHead, self).__init__(
nf=nf,
gc=gc,
num_convs=num_convs,
num_heads=num_heads,
include_skip_head=include_skip_head,
init_temperature=init_temperature,
multi_head_input=True,
collapse_heads=collapse_heads,
@ -205,10 +207,10 @@ class LowDimRRDBWrapper(nn.Module):
class SwitchedMultiHeadRRDB(SwitchedRRDB):
def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1):
super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step)
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=True)
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=True)
def set_temperature(self, temp):
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]

View File

@ -57,6 +57,10 @@ def define_G(opt, net_key='network_G'):
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
if opt_net['mhattention']:
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],