forked from mrq/DL-Art-School
Add skip heads to switcher
These pass through the input so that it can be selected by the attention mechanism.
This commit is contained in:
parent
6c27ddc9b5
commit
be7982b9ae
|
@ -37,7 +37,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||||
# If collapse_heads=True, outputs (b,f,w,h) tensor.
|
# If collapse_heads=True, outputs (b,f,w,h) tensor.
|
||||||
# If collapse_heads=False, outputs (b,heads,f,w,h) tensor.
|
# If collapse_heads=False, outputs (b,heads,f,w,h) tensor.
|
||||||
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
|
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
|
||||||
if force_block is None:
|
if force_block is None:
|
||||||
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
||||||
else:
|
else:
|
||||||
|
@ -49,6 +49,7 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||||
num_heads,
|
num_heads,
|
||||||
att_kernel_size=3,
|
att_kernel_size=3,
|
||||||
att_pads=1,
|
att_pads=1,
|
||||||
|
include_skip_head=include_skip_head,
|
||||||
initial_temperature=init_temperature,
|
initial_temperature=init_temperature,
|
||||||
multi_head_input=multi_head_input,
|
multi_head_input=multi_head_input,
|
||||||
concat_heads_into_filters=collapse_heads,
|
concat_heads_into_filters=collapse_heads,
|
||||||
|
@ -100,13 +101,14 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
|
||||||
# It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension
|
# It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension
|
||||||
# of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there.
|
# of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there.
|
||||||
class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C):
|
class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C):
|
||||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, collapse_heads=False):
|
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, collapse_heads=False):
|
||||||
rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads)
|
rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads)
|
||||||
super(SwitchedRDB_5C_MultiHead, self).__init__(
|
super(SwitchedRDB_5C_MultiHead, self).__init__(
|
||||||
nf=nf,
|
nf=nf,
|
||||||
gc=gc,
|
gc=gc,
|
||||||
num_convs=num_convs,
|
num_convs=num_convs,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
|
include_skip_head=include_skip_head,
|
||||||
init_temperature=init_temperature,
|
init_temperature=init_temperature,
|
||||||
multi_head_input=True,
|
multi_head_input=True,
|
||||||
collapse_heads=collapse_heads,
|
collapse_heads=collapse_heads,
|
||||||
|
@ -205,10 +207,10 @@ class LowDimRRDBWrapper(nn.Module):
|
||||||
class SwitchedMultiHeadRRDB(SwitchedRRDB):
|
class SwitchedMultiHeadRRDB(SwitchedRRDB):
|
||||||
def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1):
|
def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1):
|
||||||
super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step)
|
super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step)
|
||||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||||
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||||
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False)
|
self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
|
||||||
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=True)
|
self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=True)
|
||||||
|
|
||||||
def set_temperature(self, temp):
|
def set_temperature(self, temp):
|
||||||
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
|
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
|
||||||
|
|
|
@ -57,6 +57,10 @@ def define_G(opt, net_key='network_G'):
|
||||||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
init_temperature=opt_net['temperature'],
|
init_temperature=opt_net['temperature'],
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
final_temperature_step=opt_net['temperature_final_step'])
|
||||||
|
if opt_net['mhattention']:
|
||||||
|
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
|
init_temperature=opt_net['temperature'],
|
||||||
|
final_temperature_step=opt_net['temperature_final_step'])
|
||||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||||
elif which_model == 'ResGen':
|
elif which_model == 'ResGen':
|
||||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user