Add RRDB with attention

James Betker 2020-06-05 21:02:08 -06:00
parent ef5d8a0ed1
commit cbedd6340a
2 changed files with 279 additions and 5 deletions

View File

@@ -29,6 +29,33 @@ class ResidualDenseBlock_5C(nn.Module):
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
# Residual dense block (5 convolutions) that uses attention in the convolutions.
class AttentiveResidualDenseBlock_5C(ResidualDenseBlock_5C):
def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
# Build the base block first; the plain convolutions it creates are replaced below
# with attentive DynamicConv2d equivalents.
super(AttentiveResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, has_bias=True, num_convs=num_convs,
                                     initial_temperature=init_temperature)
self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, has_bias=True, num_convs=num_convs,
                                     initial_temperature=init_temperature)
self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, has_bias=True, num_convs=num_convs,
                                     initial_temperature=init_temperature)
self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, has_bias=True, num_convs=num_convs,
                                     initial_temperature=init_temperature)
self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, nf, 3, 1, 1, has_bias=True, num_convs=num_convs,
                                     initial_temperature=init_temperature)
# initialization
arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
0.1)
def set_temperature(self, temp):
self.conv1.set_attention_temperature(temp)
self.conv2.set_attention_temperature(temp)
self.conv3.set_attention_temperature(temp)
self.conv4.set_attention_temperature(temp)
self.conv5.set_attention_temperature(temp)
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
@@ -45,16 +72,30 @@ class RRDB(nn.Module):
out = self.RDB3(out)
return out * 0.2 + x
class AttentiveRRDB(RRDB):
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1):
# Deliberately bypass RRDB.__init__ (calling nn.Module.__init__ directly) so the
# default, non-attentive dense blocks are never constructed.
super(RRDB, self).__init__()
self.RDB1 = AttentiveResidualDenseBlock_5C(nf, gc, num_convs, init_temperature)
self.RDB2 = AttentiveResidualDenseBlock_5C(nf, gc, num_convs, init_temperature)
self.RDB3 = AttentiveResidualDenseBlock_5C(nf, gc, num_convs, init_temperature)
def set_temperature(self, temp):
self.RDB1.set_temperature(temp)
self.RDB2.set_temperature(temp)
self.RDB3.set_temperature(temp)
class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
+                 rrdb_block_f=None):
super(RRDBNet, self).__init__()
-        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+        if rrdb_block_f is None:
+            rrdb_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
-        self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
+        self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
@@ -63,6 +104,12 @@ class RRDBNet(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
# convolutions.
def set_temperature(self, temp):
for layer in self.rrdb_layers:
layer.set_temperature(temp)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
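As a usage note (a sketch, not part of this diff): with the new rrdb_block_f hook, an attention-equipped RRDBNet can be built by passing a partial over AttentiveRRDB, and its attention softmax temperature can then be annealed during training through set_temperature(). The hyperparameter values below are illustrative.

import functools
block_f = functools.partial(AttentiveRRDB, nf=64, gc=32, num_convs=8, init_temperature=10)
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=2, rrdb_block_f=block_f)
net.set_temperature(5)  # lower temperature -> sharper selection among the attentive convs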

View File

@@ -28,11 +28,238 @@ def initialize_weights(net_l, scale=1):
init.constant_(m.bias.data, 0.0)
-def make_layer(block, n_layers):
+def make_layer(block, n_layers, return_layers=False):
layers = []
for _ in range(n_layers):
layers.append(block())
-    return nn.Sequential(*layers)
+    if return_layers:
+        return nn.Sequential(*layers), layers
+    else:
+        return nn.Sequential(*layers)
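Illustrative only: the new return_layers flag hands back the individual blocks alongside the nn.Sequential trunk, which is what RRDBNet uses to reach each block's set_temperature().

trunk, blocks = make_layer(lambda: nn.Conv2d(64, 64, 3, 1, 1), 4, return_layers=True)
assert isinstance(trunk, nn.Sequential) and len(blocks) == 4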
class DynamicConv2d(nn.Module):
def __init__(self, nf_in_per_conv, nf_out_per_conv, kernel_size, stride=1, pads=0, has_bias=True, num_convs=8,
att_kernel_size=5, att_pads=2, initial_temperature=1):
super(DynamicConv2d, self).__init__()
# Requirements: the input filter count is even, and half of it exceeds the number of convolutions being attended over.
assert nf_in_per_conv % 2 == 0
assert nf_in_per_conv / 2 > num_convs
self.nf = nf_out_per_conv
self.num_convs = num_convs
self.conv_list = nn.ModuleList([nn.Conv2d(nf_in_per_conv, nf_out_per_conv, kernel_size, stride, pads, bias=has_bias) for _ in range(num_convs)])
self.attention_conv1 = nn.Conv2d(nf_in_per_conv, int(nf_in_per_conv/2), att_kernel_size, stride, att_pads, bias=True)
self.att_bn1 = nn.BatchNorm2d(int(nf_in_per_conv/2))
self.attention_conv2 = nn.Conv2d(int(nf_in_per_conv/2), num_convs, att_kernel_size, 1, att_pads, bias=True)
self.softmax = nn.Softmax(dim=-1)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.temperature = initial_temperature
def set_attention_temperature(self, temp):
self.temperature = temp
def forward(self, x, output_attention_weights=False):
# Build up the individual conv components first.
conv_outputs = []
for conv in self.conv_list:
conv_outputs.append(conv.forward(x))
conv_outputs = torch.stack(conv_outputs, dim=0).permute(1, 3, 4, 2, 0)
# Now calculate the attention across those convs.
conv_attention = self.lrelu(self.att_bn1(self.attention_conv1(x)))
conv_attention = self.attention_conv2(conv_attention).permute(0, 2, 3, 1)
conv_attention = self.softmax(conv_attention / self.temperature)
# conv_outputs shape: (batch, width, height, filters, sequences)
# conv_attention shape: (batch, width, height, sequences)
# We want to format them so that we can matmul them together to produce:
# desired shape: (batch, width, height, filters)
attention_result = torch.einsum("...ij,...j->...i", [conv_outputs, conv_attention])
# Remember to shift the filters back into the expected slot.
if output_attention_weights:
return attention_result.permute(0, 3, 1, 2), conv_attention
else:
return attention_result.permute(0, 3, 1, 2)
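A quick shape check (illustrative, not part of the commit): every one of the num_convs inner convolutions is applied to the input, and the attention head emits one softmax weight per convolution at each spatial position, which the einsum uses as per-pixel mixing coefficients.

conv = DynamicConv2d(16, 32, kernel_size=3, stride=1, pads=1, num_convs=4, initial_temperature=10)
x = torch.randn(2, 16, 24, 24)
y, att = conv(x, output_attention_weights=True)
assert y.shape == (2, 32, 24, 24)   # nf_out_per_conv channels, spatial size preserved
assert att.shape == (2, 24, 24, 4)  # one weight per inner conv, per position
assert torch.allclose(att.sum(dim=-1), torch.ones(2, 24, 24), atol=1e-5)  # softmax weights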
def compute_attention_specificity(att_weights, topk=3):
    # Measures how concentrated ("specific") the attention is: returns the mean total
    # weight assigned to the top-k convolutions at each position, along with the flattened
    # indices of those convolutions (useful for usage histograms).
    att = att_weights.detach()
    vals, indices = torch.topk(att, topk, dim=-1)
    avg = torch.sum(vals, dim=-1)
    avg = avg.flatten().mean()
    return avg.item(), indices.flatten().detach()
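To make the metric concrete (illustrative numbers, not from the commit): with perfectly uniform attention over 8 convolutions the top-3 mass is 3/8 = 0.375, while a fully collapsed distribution reports roughly 1.0, so larger values mean more specialized convolution selection.

att = torch.full((2, 24, 24, 8), 1.0 / 8)       # uniform attention over 8 convs
mass, used = compute_attention_specificity(att, topk=3)
print(mass)        # 0.375
print(used.shape)  # torch.Size([3456]): 2 * 24 * 24 positions * 3 top-k indices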
class DynamicConvTestModule(nn.Module):
def __init__(self):
super(DynamicConvTestModule, self).__init__()
self.init_conv = nn.Conv2d(3, 16, 3, 1, 1, bias=True)
self.conv1 = DynamicConv2d(16, 32, 3, stride=2, pads=1, num_convs=4, initial_temperature=10)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = DynamicConv2d(32, 64, 3, stride=2, pads=1, att_kernel_size=3, att_pads=1, num_convs=8, initial_temperature=10)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = DynamicConv2d(64, 128, 3, stride=2, pads=1, att_kernel_size=3, att_pads=1, num_convs=16, initial_temperature=10)
self.bn3 = nn.BatchNorm2d(128)
self.relu = nn.ReLU()
self.dense1 = nn.Linear(128 * 4 * 4, 256)
self.dense2 = nn.Linear(256, 100)
self.softmax = nn.Softmax(-1)
def set_temp(self, temp):
self.conv1.set_attention_temperature(temp)
self.conv2.set_attention_temperature(temp)
self.conv3.set_attention_temperature(temp)
def forward(self, x):
x = self.init_conv(x)
x, att1 = self.conv1(x, output_attention_weights=True)
x = self.relu(self.bn1(x))
x, att2 = self.conv2(x, output_attention_weights=True)
x = self.relu(self.bn2(x))
x, att3 = self.conv3(x, output_attention_weights=True)
x = self.relu(self.bn3(x))
atts = [att1, att2, att3]
usage_hists = []
mean = 0
for a in atts:
m, u = compute_attention_specificity(a)
mean += m
usage_hists.append(u)
mean /= 3
x = x.flatten(1)
x = self.relu(self.dense1(x))
x = self.dense2(x)
# Compute metrics across attention weights.
return self.softmax(x), mean, usage_hists
class StandardConvTestModule(nn.Module):
def __init__(self):
super(StandardConvTestModule, self).__init__()
self.init_conv = nn.Conv2d(3, 16, 3, 1, 1, bias=True)
self.conv1 = nn.Conv2d(16, 64, 3, stride=2, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
self.bn2 = nn.BatchNorm2d(128)
self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
self.bn3 = nn.BatchNorm2d(256)
self.relu = nn.ReLU()
self.dense1 = nn.Linear(256 * 4 * 4, 256)
self.dense2 = nn.Linear(256, 100)
self.softmax = nn.Softmax(-1)
def set_temp(self, temp):
pass
def forward(self, x):
x = self.init_conv(x)
x = self.conv1(x)
x = self.relu(self.bn1(x))
x = self.conv2(x)
x = self.relu(self.bn2(x))
x = self.conv3(x)
x = self.relu(self.bn3(x))
x = x.flatten(1)
x = self.relu(self.dense1(x))
x = self.dense2(x)
return self.softmax(x), 0, []
import torch.optim as optim
from torchvision import datasets, models, transforms
import tqdm
from torch.utils.tensorboard import SummaryWriter
def test_dynamic_conv():
writer = SummaryWriter()
dataset = datasets.ImageFolder("E:\\data\\cifar-100-python\\images\\train", transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
batch_size = 256
temperature = 30
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
device = torch.device("cuda:0")
net = StandardConvTestModule()
net = net.to(device)
net.set_temp(temperature)
initialize_weights(net)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
# Load state, where necessary.
'''
netstate, optimstate = torch.load("test_net.pth")
net.load_state_dict(netstate)
optimizer.load_state_dict(optimstate)
'''
criterion = nn.CrossEntropyLoss()
step = 0
running_corrects = 0
running_att_mean = 0
running_att_hist = None
for e in range(300):
tq = tqdm.tqdm(loader)
for batch, labels in tq:
batch = batch.to(device)
labels = labels.to(device)
optimizer.zero_grad()
logits, att_mean, att_usage_hist = net.forward(batch)
running_att_mean += att_mean
if running_att_hist is None:
running_att_hist = att_usage_hist
else:
for i in range(len(att_usage_hist)):
running_att_hist[i] = torch.cat([running_att_hist[i], att_usage_hist[i]]).flatten()
loss = criterion(logits, labels)
loss.backward()
'''
if step % 50 == 0:
c1_grad_avg = sum([m.weight.grad.abs().mean().item() for m in net.conv1.conv_list._modules.values()]) / len(net.conv1.conv_list._modules)
c1a_grad_avg = (net.conv1.attention_conv1.weight.grad.abs().mean() + net.conv1.attention_conv2.weight.grad.abs().mean()) / 2
c2_grad_avg = sum([m.weight.grad.abs().mean().item() for m in net.conv2.conv_list._modules.values()]) / len(net.conv2.conv_list._modules)
c2a_grad_avg = (net.conv2.attention_conv1.weight.grad.abs().mean() + net.conv2.attention_conv2.weight.grad.abs().mean()) / 2
c3_grad_avg = sum([m.weight.grad.abs().mean().item() for m in net.conv3.conv_list._modules.values()]) / len(net.conv3.conv_list._modules)
c3a_grad_avg = (net.conv3.attention_conv1.weight.grad.abs().mean() + net.conv3.attention_conv2.weight.grad.abs().mean()) / 2
writer.add_scalar("c1_grad_avg", c1_grad_avg, global_step=step)
writer.add_scalar("c2_grad_avg", c2_grad_avg, global_step=step)
writer.add_scalar("c3_grad_avg", c3_grad_avg, global_step=step)
writer.add_scalar("c1a_grad_avg", c1a_grad_avg, global_step=step)
writer.add_scalar("c2a_grad_avg", c2a_grad_avg, global_step=step)
writer.add_scalar("c3a_grad_avg", c3a_grad_avg, global_step=step)
'''
optimizer.step()
_, preds = torch.max(logits, 1)
running_corrects += torch.sum(preds == labels.data)
if step % 50 == 0:
print("Step: %i, Loss: %f, acc: %f, att_mean: %f" % (step, loss.item(), running_corrects / (50.0 * batch_size),
running_att_mean / 50.0))
writer.add_scalar("Loss", loss.item(), global_step=step)
writer.add_scalar("Accuracy", running_corrects / (50.0 * batch_size), global_step=step)
writer.add_scalar("Att Mean", running_att_mean / 50, global_step=step)
for i in range(len(running_att_hist)):
writer.add_histogram("Att Hist %i" % (i,), running_att_hist[i], global_step=step)
writer.flush()
running_corrects = 0
running_att_mean = 0
running_att_hist = None
if step % 1000 == 0:
temperature = max(temperature-1, 1)
net.set_temp(temperature)
print("Temperature drop. Now: %i" % (temperature,))
step += 1
torch.save((net.state_dict(), optimizer.state_dict()), "test_net_standard.pth")
if __name__ == '__main__':
test_dynamic_conv()
class ResidualBlock(nn.Module):
'''Residual block with BN