Fix mhead attention integration bug for RRDB

James Betker 2020-06-10 12:02:33 -06:00
parent 12e8fad079
commit 43b7fccc89


@@ -32,17 +32,32 @@ class ResidualDenseBlock_5C(nn.Module):
 # Multiple 5-channel residual block that uses learned switching to diversify its outputs.
-class SwitchedRDB_5C(switched_conv.SwitchedAbstractBlock):
-    def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
+class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
+    def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1):
         rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
         super(SwitchedRDB_5C, self).__init__(
             rdb5c,
             nf,
             num_convs,
+            num_heads,
             att_kernel_size=3,
             att_pads=1,
             initial_temperature=init_temperature,
         )
-        arch_util.initialize_weights([self.switcher.attention_conv1, self.switcher.attention_conv2], 1)
+        self.mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)
+        arch_util.initialize_weights([sw.attention_conv1 for sw in self.switches] +
+                                     [sw.attention_conv2 for sw in self.switches] +
+                                     [self.mhead_collapse], 1)
+
+    def forward(self, x, output_attention_weights=False):
+        outs = super(SwitchedRDB_5C, self).forward(x, output_attention_weights)
+        if output_attention_weights:
+            outs, atts = outs
+            # outs must be collapsed back down to a single head's worth of data.
+            return self.mhead_collapse(outs), atts
+        return self.mhead_collapse(outs)
 
 
 class RRDB(nn.Module):
@@ -63,18 +78,18 @@ class RRDB(nn.Module):
 class AttentiveRRDB(RRDB):
     def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
         super(RRDB, self).__init__()
-        self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
-        self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
-        self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
+        self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
+        self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
+        self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
         self.init_temperature = init_temperature
         self.final_temperature_step = final_temperature_step
         self.running_mean = 0
         self.running_count = 0
 
     def set_temperature(self, temp):
-        self.RDB1.switcher.set_attention_temperature(temp)
-        self.RDB2.switcher.set_attention_temperature(temp)
-        self.RDB3.switcher.set_attention_temperature(temp)
+        for rdb in [self.RDB1, self.RDB2, self.RDB3]:
+            for sw in rdb.switches:
+                sw.set_attention_temperature(temp)
 
     def forward(self, x):
         out, att1 = self.RDB1(x, True)
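
Note on the fix: the first hunk adds num_heads ahead of init_temperature in the SwitchedRDB_5C signature, so AttentiveRRDB's old positional calls silently bound init_temperature to the num_heads slot; the second hunk converts those call sites to keyword arguments. A minimal, self-contained sketch of that failure mode (a toy function, not the repo's code):

    # Toy stand-in for the SwitchedRDB_5C constructor after num_heads was added.
    def make_block(nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1):
        return {"num_heads": num_heads, "init_temperature": init_temperature}

    # Old call site: the intended temperature (10) lands in the num_heads slot.
    broken = make_block(64, 32, 8, 10)
    assert broken == {"num_heads": 10, "init_temperature": 1}

    # Fixed call site: keyword arguments survive the signature change.
    fixed = make_block(64, 32, num_convs=8, init_temperature=10)
    assert fixed == {"num_heads": 2, "init_temperature": 10}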
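
For reference, mhead_collapse is what lets the multi-head block remain a drop-in replacement: each of the num_heads switch heads emits nf channels, the superclass output presumably carries them concatenated along the channel dimension, and the 1x1 convolution projects the result back down to nf. A shape-only sketch under that assumption (batch and spatial sizes are illustrative):

    import torch
    import torch.nn as nn

    num_heads, nf = 2, 64
    mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)  # same layer as in the diff

    # Stand-in for the concatenated per-head outputs: (batch, num_heads * nf, H, W).
    heads_out = torch.randn(4, num_heads * nf, 32, 32)
    collapsed = mhead_collapse(heads_out)
    assert collapsed.shape == (4, nf, 32, 32)  # back to a single head's worth of channels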