diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 7be467d5..fc7725ce 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -37,7 +37,7 @@ class ResidualDenseBlock_5C(nn.Module): # If collapse_heads=True, outputs (b,f,w,h) tensor. # If collapse_heads=False, outputs (b,heads,f,w,h) tensor. class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock): - def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None): + def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None): if force_block is None: rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc) else: @@ -49,6 +49,7 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock): num_heads, att_kernel_size=3, att_pads=1, + include_skip_head=include_skip_head, initial_temperature=init_temperature, multi_head_input=multi_head_input, concat_heads_into_filters=collapse_heads, @@ -100,13 +101,14 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C): # It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension # of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there. class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C): - def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1, collapse_heads=False): + def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, collapse_heads=False): rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads) super(SwitchedRDB_5C_MultiHead, self).__init__( nf=nf, gc=gc, num_convs=num_convs, num_heads=num_heads, + include_skip_head=include_skip_head, init_temperature=init_temperature, multi_head_input=True, collapse_heads=collapse_heads, @@ -205,10 +207,10 @@ class LowDimRRDBWrapper(nn.Module): class SwitchedMultiHeadRRDB(SwitchedRRDB): def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1): super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step) - self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False) - self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False) - self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=False) - self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, init_temperature=init_temperature, collapse_heads=True) + self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False) + self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False) + self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False) + self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=True) def set_temperature(self, temp): [sw.set_attention_temperature(temp) for sw in self.RDB1.switches] diff --git a/codes/models/networks.py b/codes/models/networks.py index 72ebd8b1..cc1fb6e0 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -57,6 +57,10 @@ def define_G(opt, net_key='network_G'): block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'], init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step']) + if opt_net['mhattention']: + block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'], + init_temperature=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step']) netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f) elif which_model == 'ResGen': netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],