From 43b7fccc891dc5fd81b441e24f0c953ea0925814 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 10 Jun 2020 12:02:33 -0600
Subject: [PATCH] Fix mhead attention integration bug for RRDB

---
 codes/models/archs/RRDBNet_arch.py | 33 ++++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index 2f56e4f5..f0517df4 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -32,17 +32,32 @@ class ResidualDenseBlock_5C(nn.Module):
 
 
 # Multiple 5-channel residual block that uses learned switching to diversify its outputs.
-class SwitchedRDB_5C(switched_conv.SwitchedAbstractBlock):
-    def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
+class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
+    def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1):
         rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
         super(SwitchedRDB_5C, self).__init__(
             rdb5c,
             nf,
             num_convs,
+            num_heads,
+            att_kernel_size=3,
+            att_pads=1,
             initial_temperature=init_temperature,
         )
+        self.mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)
+
+        arch_util.initialize_weights([sw.attention_conv1 for sw in self.switches] +
+                                     [sw.attention_conv2 for sw in self.switches] +
+                                     [self.mhead_collapse], 1)
+
+    def forward(self, x, output_attention_weights=False):
+        outs = super(SwitchedRDB_5C, self).forward(x, output_attention_weights)
+        if not output_attention_weights:
+            return self.mhead_collapse(outs)
+        outs, atts = outs
+        # outs need to be collapsed back down to a single head's worth of data.
+        return self.mhead_collapse(outs), atts
 
-        arch_util.initialize_weights([self.switcher.attention_conv1, self.switcher.attention_conv2], 1)
 
 
 class RRDB(nn.Module):
@@ -63,18 +78,18 @@ class RRDB(nn.Module):
 class AttentiveRRDB(RRDB):
     def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
         super(RRDB, self).__init__()
-        self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
-        self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
-        self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
+        self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
+        self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
+        self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
         self.init_temperature = init_temperature
         self.final_temperature_step = final_temperature_step
         self.running_mean = 0
         self.running_count = 0
 
     def set_temperature(self, temp):
-        self.RDB1.switcher.set_attention_temperature(temp)
-        self.RDB2.switcher.set_attention_temperature(temp)
-        self.RDB3.switcher.set_attention_temperature(temp)
+        for rdb in [self.RDB1, self.RDB2, self.RDB3]:
+            for sw in rdb.switches:
+                sw.set_attention_temperature(temp)
 
     def forward(self, x):
         out, att1 = self.RDB1(x, True)
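
Note for reviewers (not part of the commit): a minimal smoke test of the fixed multi-head path. The import root and tensor sizes are illustrative assumptions, and it assumes switched_conv's MultiHeadSwitchedAbstractBlock concatenates its num_heads outputs along the channel dimension, which is what makes mhead_collapse a 1x1 conv from num_heads * nf back down to nf.

import torch
from models.archs.RRDBNet_arch import SwitchedRDB_5C  # assumed import root

# Two heads, so the parent forward should emit 2 * 64 channels, which
# mhead_collapse reduces back to a single head's worth (64 channels).
block = SwitchedRDB_5C(nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=10)
x = torch.randn(1, 64, 32, 32)
out, atts = block(x, output_attention_weights=True)
assert out.shape == x.shape  # spatial dims and channel count preserved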
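
The set_temperature change follows from the same switch of base class: MultiHeadSwitchedAbstractBlock exposes a list of switches rather than the single `switcher` attribute of the old block, so the temperature must be fanned out to every switch in every RDB. A hedged sketch of the annealing loop a trainer might run; the schedule, step count, and import root are illustrative and not part of this patch.

from models.archs.RRDBNet_arch import AttentiveRRDB  # assumed import root

rrdb = AttentiveRRDB(nf=64, gc=32, num_convs=8, init_temperature=10,
                     final_temperature_step=10000)
for step in range(10000):
    # Anneal linearly from init_temperature down to 1 over final_temperature_step
    # steps; the schedule actually used by the training code is not shown here.
    temp = max(1.0, 10 - (10 - 1) * step / 10000)
    rrdb.set_temperature(temp)
    # ... forward/backward pass goes here ...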