Fix attention conv bugs

parent cbedd6340a
commit 063719c5cc
@@ -34,15 +34,15 @@ class AttentiveResidualDenseBlock_5C(ResidualDenseBlock_5C):
     def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
         super(AttentiveResidualDenseBlock_5C, self).__init__()
         # gc: growth channel, i.e. intermediate channels
-        self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)
-        self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
+        self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, nf, 3, 1, 1, num_convs=num_convs,
                                              initial_temperature=init_temperature)

         # initialization
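For reference, the conv5 channel fix (gc to nf) matters because the block's output is scaled and added back onto its nf-channel input. Below is a minimal, self-contained sketch assuming the forward pass inherited from ResidualDenseBlock_5C follows the usual ESRGAN dense-block pattern; DynamicConv2d is replaced by a plain Conv2d stand-in, so only the channel arithmetic is illustrated, not the attention mechanism.

import torch
import torch.nn as nn

# Hypothetical stand-in for arch_util.DynamicConv2d: same
# (in_channels, out_channels, kernel, stride, padding) head arguments,
# minus the attention-specific keywords.
def dynamic_conv_stub(in_ch, out_ch):
    return nn.Conv2d(in_ch, out_ch, 3, 1, 1)

class DenseBlockSketch(nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = dynamic_conv_stub(nf, gc)
        self.conv2 = dynamic_conv_stub(nf + gc, gc)
        self.conv3 = dynamic_conv_stub(nf + 2 * gc, gc)
        self.conv4 = dynamic_conv_stub(nf + 3 * gc, gc)
        # conv5 must map back to nf channels so the residual add below is valid;
        # this is the channel-count bug the commit fixes.
        self.conv5 = dynamic_conv_stub(nf + 4 * gc, nf)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

out = DenseBlockSketch()(torch.randn(1, 64, 16, 16))  # -> shape (1, 64, 16, 16)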
@@ -44,7 +44,7 @@ class DynamicConv2d(nn.Module):

         # Requirements: input filter count is even, and there are more filters than there are sequences to attend to.
         assert nf_in_per_conv % 2 == 0
-        assert nf_in_per_conv / 2 > num_convs
+        assert nf_in_per_conv / 2 >= num_convs

         self.nf = nf_out_per_conv
         self.num_convs = num_convs
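The relaxed assertion admits the boundary case where half of the input filters is exactly equal to the number of convolutions. A short illustration, with example values chosen only to show that boundary:

nf_in_per_conv, num_convs = 64, 32

assert nf_in_per_conv % 2 == 0
assert nf_in_per_conv / 2 >= num_convs   # passes: 32.0 >= 32
# The previous check, nf_in_per_conv / 2 > num_convs, would have raised AssertionError here.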
@@ -75,7 +75,9 @@ class DynamicConv2d(nn.Module):
         # conv_attention shape: (batch, width, height, sequences)
         # We want to format them so that we can matmul them together to produce:
         # desired shape: (batch, width, height, filters)
-        attention_result = torch.einsum("...ij,...j->...i", [conv_outputs, conv_attention])
+        # Note: conv_attention will generally be cast to float32 regardless of the input type, so cast conv_outputs to
+        # float32 as well to match it.
+        attention_result = torch.einsum("...ij,...j->...i", [conv_outputs.to(dtype=torch.float32), conv_attention])

         # Remember to shift the filters back into the expected slot.
         if output_attention_weights:
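To see what the cast buys, here is a standalone sketch of the attention contraction using the shapes named in the comments above. The half-precision conv_outputs and the softmax-produced float32 conv_attention are assumptions for illustration; torch.einsum generally expects its operands to share a dtype, which is why the commit casts conv_outputs up to float32.

import torch

# Shapes follow the comments in the diff; the sizes themselves are arbitrary.
b, w, h, filters, sequences = 2, 8, 8, 32, 4

conv_outputs = torch.randn(b, w, h, filters, sequences, dtype=torch.float16)
conv_attention = torch.softmax(torch.randn(b, w, h, sequences), dim=-1)  # float32

# Cast conv_outputs to float32 so both operands match, then contract away the
# "sequences" axis to produce (batch, width, height, filters).
attention_result = torch.einsum("...ij,...j->...i",
                                conv_outputs.to(dtype=torch.float32), conv_attention)
print(attention_result.shape)  # torch.Size([2, 8, 8, 32])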
@@ -175,7 +177,7 @@ from torch.utils.tensorboard import SummaryWriter

 def test_dynamic_conv():
     writer = SummaryWriter()
-    dataset = datasets.ImageFolder("E:\\data\\cifar-100-python\\images\\train", transforms.Compose([
+    dataset = datasets.ImageFolder("C:\\data\\cifar-100-python\\images\\train", transforms.Compose([
         transforms.Resize(32, 32),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
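For completeness, the test dataset above would typically be consumed through a DataLoader. A minimal sketch, assuming an ImageFolder-style directory exists at the path used by the test; the path, batch size, and the tuple form of Resize are illustrative choices, not part of this commit.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Point this at any ImageFolder-style directory tree.
dataset = datasets.ImageFolder("C:\\data\\cifar-100-python\\images\\train", transforms.Compose([
    # torchvision's Resize takes a size (or an (h, w) tuple) as its first argument.
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]))
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
images, labels = next(iter(loader))  # images: (64, 3, 32, 32)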
@@ -9,6 +9,7 @@ import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.ResGen_arch as ResGen_arch
 import models.archs.biggan_gen_arch as biggan_arch
 import models.archs.feature_arch as feature_arch
+import functools

 # Generator
 def define_G(opt, net_key='network_G'):

@@ -31,6 +32,10 @@ def define_G(opt, net_key='network_G'):
     elif which_model == 'AssistedRRDBNet':
         netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                             nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
+    elif which_model == 'AttentiveRRDBNet':
+        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
+                                    rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc']))
     elif which_model == 'ResGen':
         netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
                                           upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
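The new 'AttentiveRRDBNet' branch reuses the stock RRDBNet but swaps its block factory via functools.partial, baking the block hyperparameters (nf, gc) into a zero-argument callable. A self-contained toy sketch of that pattern; ToyBlock and ToyNet are hypothetical stand-ins for AttentiveRRDB and RRDBNet, used only to show how the partial is consumed.

import functools
import torch.nn as nn

class ToyBlock(nn.Module):
    # gc is accepted only to mirror the keyword arguments baked into the partial.
    def __init__(self, nf, gc):
        super().__init__()
        self.body = nn.Conv2d(nf, nf, 3, 1, 1)

    def forward(self, x):
        return x + self.body(x)

class ToyNet(nn.Module):
    def __init__(self, nb, block_f):
        super().__init__()
        # The network only calls block_f(); it never needs to know nf or gc.
        self.blocks = nn.Sequential(*[block_f() for _ in range(nb)])

    def forward(self, x):
        return self.blocks(x)

block_f = functools.partial(ToyBlock, nf=64, gc=32)
net = ToyNet(nb=23, block_f=block_f)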