Fix mhead attention integration bug for RRDB
This commit is contained in:
parent
12e8fad079
commit
43b7fccc89
|
@ -32,17 +32,32 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# Multiple 5-channel residual block that uses learned switching to diversify its outputs.
|
# Multiple 5-channel residual block that uses learned switching to diversify its outputs.
|
||||||
class SwitchedRDB_5C(switched_conv.SwitchedAbstractBlock):
|
class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||||
def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
|
def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, init_temperature=1):
|
||||||
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
|
||||||
super(SwitchedRDB_5C, self).__init__(
|
super(SwitchedRDB_5C, self).__init__(
|
||||||
rdb5c,
|
rdb5c,
|
||||||
nf,
|
nf,
|
||||||
num_convs,
|
num_convs,
|
||||||
|
num_heads,
|
||||||
|
att_kernel_size=3,
|
||||||
|
att_pads=1,
|
||||||
initial_temperature=init_temperature,
|
initial_temperature=init_temperature,
|
||||||
)
|
)
|
||||||
|
self.mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)
|
||||||
|
|
||||||
|
arch_util.initialize_weights([sw.attention_conv1 for sw in self.switches] +
|
||||||
|
[sw.attention_conv2 for sw in self.switches] +
|
||||||
|
[self.mhead_collapse], 1)
|
||||||
|
|
||||||
|
def forward(self, x, output_attention_weights=False):
|
||||||
|
outs = super(SwitchedRDB_5C, self).forward(x, output_attention_weights)
|
||||||
|
if output_attention_weights:
|
||||||
|
outs, atts = outs
|
||||||
|
# outs need to be collapsed back down to a single heads worth of data.
|
||||||
|
out = self.mhead_collapse(outs)
|
||||||
|
return out, atts
|
||||||
|
|
||||||
arch_util.initialize_weights([self.switcher.attention_conv1, self.switcher.attention_conv2], 1)
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
|
@ -63,18 +78,18 @@ class RRDB(nn.Module):
|
||||||
class AttentiveRRDB(RRDB):
|
class AttentiveRRDB(RRDB):
|
||||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||||
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB2 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||||
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
|
self.RDB3 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||||
self.init_temperature = init_temperature
|
self.init_temperature = init_temperature
|
||||||
self.final_temperature_step = final_temperature_step
|
self.final_temperature_step = final_temperature_step
|
||||||
self.running_mean = 0
|
self.running_mean = 0
|
||||||
self.running_count = 0
|
self.running_count = 0
|
||||||
|
|
||||||
def set_temperature(self, temp):
|
def set_temperature(self, temp):
|
||||||
self.RDB1.switcher.set_attention_temperature(temp)
|
[sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
|
||||||
self.RDB2.switcher.set_attention_temperature(temp)
|
[sw.set_attention_temperature(temp) for sw in self.RDB2.switches]
|
||||||
self.RDB3.switcher.set_attention_temperature(temp)
|
[sw.set_attention_temperature(temp) for sw in self.RDB3.switches]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out, att1 = self.RDB1(x, True)
|
out, att1 = self.RDB1(x, True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user