Fix attention conv bugs

This commit is contained in:
James Betker 2020-06-06 18:29:25 -06:00
parent cbedd6340a
commit 063719c5cc
3 changed files with 15 additions and 8 deletions

View File

@ -34,15 +34,15 @@ class AttentiveResidualDenseBlock_5C(ResidualDenseBlock_5C):
def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1): def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
super(AttentiveResidualDenseBlock_5C, self).__init__() super(AttentiveResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels # gc: growth channel, i.e. intermediate channels
self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, bias=bias, num_convs=num_convs, self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, num_convs=num_convs,
initial_temperature=init_temperature) initial_temperature=init_temperature)
self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs, self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, num_convs=num_convs,
initial_temperature=init_temperature) initial_temperature=init_temperature)
self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs, self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, num_convs=num_convs,
initial_temperature=init_temperature) initial_temperature=init_temperature)
self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs, self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, num_convs=num_convs,
initial_temperature=init_temperature) initial_temperature=init_temperature)
self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs, self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, nf, 3, 1, 1, num_convs=num_convs,
initial_temperature=init_temperature) initial_temperature=init_temperature)
# initialization # initialization

View File

@ -44,7 +44,7 @@ class DynamicConv2d(nn.Module):
# Requirements: input filter count is even, and there are more filters than there are sequences to attend to. # Requirements: input filter count is even, and there are more filters than there are sequences to attend to.
assert nf_in_per_conv % 2 == 0 assert nf_in_per_conv % 2 == 0
assert nf_in_per_conv / 2 > num_convs assert nf_in_per_conv / 2 >= num_convs
self.nf = nf_out_per_conv self.nf = nf_out_per_conv
self.num_convs = num_convs self.num_convs = num_convs
@ -75,7 +75,9 @@ class DynamicConv2d(nn.Module):
# conv_attention shape: (batch, width, height, sequences) # conv_attention shape: (batch, width, height, sequences)
# We want to format them so that we can matmul them together to produce: # We want to format them so that we can matmul them together to produce:
# desired shape: (batch, width, height, filters) # desired shape: (batch, width, height, filters)
attention_result = torch.einsum("...ij,...j->...i", [conv_outputs, conv_attention]) # Note: conv_attention will generally be cast to float32 regardless of the input type, so cast conv_outputs to
# float32 as well to match it.
attention_result = torch.einsum("...ij,...j->...i", [conv_outputs.to(dtype=torch.float32), conv_attention])
# Remember to shift the filters back into the expected slot. # Remember to shift the filters back into the expected slot.
if output_attention_weights: if output_attention_weights:
@ -175,7 +177,7 @@ from torch.utils.tensorboard import SummaryWriter
def test_dynamic_conv(): def test_dynamic_conv():
writer = SummaryWriter() writer = SummaryWriter()
dataset = datasets.ImageFolder("E:\\data\\cifar-100-python\\images\\train", transforms.Compose([ dataset = datasets.ImageFolder("C:\\data\\cifar-100-python\\images\\train", transforms.Compose([
transforms.Resize(32, 32), transforms.Resize(32, 32),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToTensor(), transforms.ToTensor(),

View File

@ -9,6 +9,7 @@ import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.ResGen_arch as ResGen_arch import models.archs.ResGen_arch as ResGen_arch
import models.archs.biggan_gen_arch as biggan_arch import models.archs.biggan_gen_arch as biggan_arch
import models.archs.feature_arch as feature_arch import models.archs.feature_arch as feature_arch
import functools
# Generator # Generator
def define_G(opt, net_key='network_G'): def define_G(opt, net_key='network_G'):
@ -31,6 +32,10 @@ def define_G(opt, net_key='network_G'):
elif which_model == 'AssistedRRDBNet': elif which_model == 'AssistedRRDBNet':
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale) nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
elif which_model == 'AttentiveRRDBNet':
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc']))
elif which_model == 'ResGen': elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])