Fix mhead attention integration bug for RRDB
This commit is contained in:
parent
12e8fad079
commit
43b7fccc89
|
@ -32,17 +32,32 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||
|
||||
|
||||
# Multiple 5-channel residual block that uses learned switching to diversify its outputs.
|
||||
class SwitchedRDB_5C(switched_conv.SwitchedAbstractBlock):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
|
||||
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1):
|
||||
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
||||
super(SwitchedRDB_5C, self).__init__(
|
||||
rdb5c,
|
||||
nf,
|
||||
num_convs,
|
||||
num_heads,
|
||||
att_kernel_size=3,
|
||||
att_pads=1,
|
||||
initial_temperature=init_temperature,
|
||||
)
|
||||
self.mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)
|
||||
|
||||
arch_util.initialize_weights([sw.attention_conv1 for sw in self.switches] +
|
||||
[sw.attention_conv2 for sw in self.switches] +
|
||||
[self.mhead_collapse], 1)
|
||||
|
||||
def forward(self, x, output_attention_weights=False):
|
||||
outs = super(SwitchedRDB_5C, self).forward(x, output_attention_weights)
|
||||
if output_attention_weights:
|
||||
outs, atts = outs
|
||||
# outs need to be collapsed back down to a single heads worth of data.
|
||||
out = self.mhead_collapse(outs)
|
||||
return out, atts
|
||||
|
||||
arch_util.initialize_weights([self.switcher.attention_conv1, self.switcher.attention_conv2], 1)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
|
@ -63,18 +78,18 @@ class RRDB(nn.Module):
|
|||
class AttentiveRRDB(RRDB):
|
||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = final_temperature_step
|
||||
self.running_mean = 0
|
||||
self.running_count = 0
|
||||
|
||||
def set_temperature(self, temp):
|
||||
self.RDB1.switcher.set_attention_temperature(temp)
|
||||
self.RDB2.switcher.set_attention_temperature(temp)
|
||||
self.RDB3.switcher.set_attention_temperature(temp)
|
||||
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
|
||||
[sw.set_attention_temperature(temp) for sw in self.RDB2.switches]
|
||||
[sw.set_attention_temperature(temp) for sw in self.RDB3.switches]
|
||||
|
||||
def forward(self, x):
|
||||
out, att1 = self.RDB1(x, True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user