From 063719c5ccb8d4e6adf534b09a2287bf4c885210 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 6 Jun 2020 18:29:25 -0600
Subject: [PATCH] Fix attention conv bugs

---
 codes/models/archs/RRDBNet_arch.py | 10 +++++-----
 codes/models/archs/arch_util.py    |  8 +++++---
 codes/models/networks.py           |  5 +++++
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index 24de959d..3255fc8d 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -34,15 +34,15 @@ class AttentiveResidualDenseBlock_5C(ResidualDenseBlock_5C):
     def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
         super(AttentiveResidualDenseBlock_5C, self).__init__()
         # gc: growth channel, i.e. intermediate channels
-        self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, nf, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
 
         # initialization
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index fa0d92f5..e146fe7d 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -44,7 +44,7 @@ class DynamicConv2d(nn.Module):
 
         # Requirements: input filter count is even, and there are more filters than there are sequences to attend to.
         assert nf_in_per_conv % 2 == 0
-        assert nf_in_per_conv / 2 > num_convs
+        assert nf_in_per_conv / 2 >= num_convs
 
         self.nf = nf_out_per_conv
         self.num_convs = num_convs
@@ -75,7 +75,9 @@ class DynamicConv2d(nn.Module):
         # conv_attention shape: (batch, width, height, sequences)
         # We want to format them so that we can matmul them together to produce:
         # desired shape: (batch, width, height, filters)
-        attention_result = torch.einsum("...ij,...j->...i", [conv_outputs, conv_attention])
+        # Note: conv_attention will generally be cast to float32 regardless of the input type, so cast conv_outputs to
+        # float32 as well to match it.
+        attention_result = torch.einsum("...ij,...j->...i", [conv_outputs.to(dtype=torch.float32), conv_attention])
 
         # Remember to shift the filters back into the expected slot.
         if output_attention_weights:
@@ -175,7 +177,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 def test_dynamic_conv():
     writer = SummaryWriter()
-    dataset = datasets.ImageFolder("E:\\data\\cifar-100-python\\images\\train", transforms.Compose([
+    dataset = datasets.ImageFolder("C:\\data\\cifar-100-python\\images\\train", transforms.Compose([
         transforms.Resize(32, 32),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
diff --git a/codes/models/networks.py b/codes/models/networks.py
index bce2e531..f8635fc7 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -9,6 +9,7 @@ import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.ResGen_arch as ResGen_arch
 import models.archs.biggan_gen_arch as biggan_arch
 import models.archs.feature_arch as feature_arch
+import functools
 
 # Generator
 def define_G(opt, net_key='network_G'):
@@ -31,6 +32,10 @@ def define_G(opt, net_key='network_G'):
     elif which_model == 'AssistedRRDBNet':
         netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                                             nb=opt_net['nb'], scale=scale)
+    elif which_model == 'AttentiveRRDBNet':
+        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
+                                    rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc']))
     elif which_model == 'ResGen':
         netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
                                           upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
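
Note: for reference, below is a minimal standalone sketch of the mixed-precision issue the arch_util.py einsum change works around. The shapes, the softmax weights, and the float16 inputs are illustrative assumptions for demonstration only, not code from this repository; the underlying point is that torch.einsum requires its operands to share a dtype, so when the attention weights are float32 the stacked conv outputs must be cast up to match.

    import torch

    # Illustrative shapes: per-conv outputs stacked on a trailing "sequences"
    # axis, plus one attention weight per candidate conv at each position.
    # conv_outputs:   (batch, width, height, filters, num_convs)
    # conv_attention: (batch, width, height, num_convs)
    batch, width, height, filters, num_convs = 2, 8, 8, 16, 4
    conv_outputs = torch.randn(batch, width, height, filters, num_convs).to(torch.float16)  # e.g. under fp16 training
    conv_attention = torch.softmax(torch.randn(batch, width, height, num_convs), dim=-1)    # float32 weights

    # With mismatched dtypes (float16 vs. float32) einsum raises a RuntimeError,
    # hence the cast in the patch:
    attention_result = torch.einsum("...ij,...j->...i",
                                    conv_outputs.to(dtype=torch.float32), conv_attention)
    print(attention_result.shape)  # torch.Size([2, 8, 8, 16]): one blended filter map per position

Casting conv_outputs up to float32 (rather than the attention weights down to float16) also keeps the weighted blend itself in full precision.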