Add skip heads to switcher
These pass through the input so that it can be selected by the attention mechanism.
This commit is contained in:
parent
6c27ddc9b5
commit
be7982b9ae
|
@ -37,7 +37,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||
# If collapse_heads=True, outputs (b,f,w,h) tensor.
|
||||
# If collapse_heads=False, outputs (b,heads,f,w,h) tensor.
|
||||
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
|
||||
if force_block is None:
|
||||
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
||||
else:
|
||||
|
@ -49,6 +49,7 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
|||
num_heads,
|
||||
att_kernel_size=3,
|
||||
att_pads=1,
|
||||
include_skip_head=include_skip_head,
|
||||
initial_temperature=init_temperature,
|
||||
multi_head_input=multi_head_input,
|
||||
concat_heads_into_filters=collapse_heads,
|
||||
|
@ -100,13 +101,14 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
|
|||
# It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension
|
||||
# of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there.
|
||||
class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, collapse_heads=False):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, collapse_heads=False):
|
||||
rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads)
|
||||
super(SwitchedRDB_5C_MultiHead, self).__init__(
|
||||
nf=nf,
|
||||
gc=gc,
|
||||
num_convs=num_convs,
|
||||
num_heads=num_heads,
|
||||
include_skip_head=include_skip_head,
|
||||
init_temperature=init_temperature,
|
||||
multi_head_input=True,
|
||||
collapse_heads=collapse_heads,
|
||||
|
@ -205,10 +207,10 @@ class LowDimRRDBWrapper(nn.Module):
|
|||
class SwitchedMultiHeadRRDB(SwitchedRRDB):
|
||||
def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1):
|
||||
super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step)
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=True)
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=True)
|
||||
|
||||
def set_temperature(self, temp):
|
||||
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
|
||||
|
|
|
@ -57,6 +57,10 @@ def define_G(opt, net_key='network_G'):
|
|||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
if opt_net['mhattention']:
|
||||
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||
elif which_model == 'ResGen':
|
||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||
|
|
Loading…
Reference in New Issue
Block a user