forked from mrq/DL-Art-School
Fix attention conv bugs
This commit is contained in:
parent
cbedd6340a
commit
063719c5cc
|
@ -34,15 +34,15 @@ class AttentiveResidualDenseBlock_5C(ResidualDenseBlock_5C):
|
|||
def __init__(self, nf=64, gc=32, num_convs=8, init_temperature=1):
|
||||
super(AttentiveResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
|
||||
self.conv1 = arch_util.DynamicConv2d(nf, gc, 3, 1, 1, num_convs=num_convs,
|
||||
initial_temperature=init_temperature)
|
||||
self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
|
||||
self.conv2 = arch_util.DynamicConv2d(nf + gc, gc, 3, 1, 1, num_convs=num_convs,
|
||||
initial_temperature=init_temperature)
|
||||
self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
|
||||
self.conv3 = arch_util.DynamicConv2d(nf + 2 * gc, gc, 3, 1, 1, num_convs=num_convs,
|
||||
initial_temperature=init_temperature)
|
||||
self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
|
||||
self.conv4 = arch_util.DynamicConv2d(nf + 3 * gc, gc, 3, 1, 1, num_convs=num_convs,
|
||||
initial_temperature=init_temperature)
|
||||
self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, gc, 3, 1, 1, bias=bias, num_convs=num_convs,
|
||||
self.conv5 = arch_util.DynamicConv2d(nf + 4 * gc, nf, 3, 1, 1, num_convs=num_convs,
|
||||
initial_temperature=init_temperature)
|
||||
|
||||
# initialization
|
||||
|
|
|
@ -44,7 +44,7 @@ class DynamicConv2d(nn.Module):
|
|||
|
||||
# Requirements: the input filter count is even, and half of the input filter count is at least the number of convs (sequences) to attend to.
|
||||
assert nf_in_per_conv % 2 == 0
|
||||
assert nf_in_per_conv / 2 > num_convs
|
||||
assert nf_in_per_conv / 2 >= num_convs
|
||||
|
||||
self.nf = nf_out_per_conv
|
||||
self.num_convs = num_convs
|
||||
|
@ -75,7 +75,9 @@ class DynamicConv2d(nn.Module):
|
|||
# conv_attention shape: (batch, width, height, sequences)
|
||||
# We want to format them so that we can matmul them together to produce:
|
||||
# desired shape: (batch, width, height, filters)
|
||||
attention_result = torch.einsum("...ij,...j->...i", [conv_outputs, conv_attention])
|
||||
# Note: conv_attention will generally be cast to float32 regardless of the input type, so cast conv_outputs to
|
||||
# float32 as well to match it.
|
||||
attention_result = torch.einsum("...ij,...j->...i", [conv_outputs.to(dtype=torch.float32), conv_attention])
|
||||
|
||||
# Remember to shift the filters back into the expected slot.
|
||||
if output_attention_weights:
|
||||
|
@ -175,7 +177,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
|
||||
def test_dynamic_conv():
|
||||
writer = SummaryWriter()
|
||||
dataset = datasets.ImageFolder("E:\\data\\cifar-100-python\\images\\train", transforms.Compose([
|
||||
dataset = datasets.ImageFolder("C:\\data\\cifar-100-python\\images\\train", transforms.Compose([
|
||||
transforms.Resize(32, 32),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
|
|
|
@ -9,6 +9,7 @@ import models.archs.HighToLowResNet as HighToLowResNet
|
|||
import models.archs.ResGen_arch as ResGen_arch
|
||||
import models.archs.biggan_gen_arch as biggan_arch
|
||||
import models.archs.feature_arch as feature_arch
|
||||
import functools
|
||||
|
||||
# Generator
|
||||
def define_G(opt, net_key='network_G'):
|
||||
|
@ -31,6 +32,10 @@ def define_G(opt, net_key='network_G'):
|
|||
elif which_model == 'AssistedRRDBNet':
|
||||
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
||||
elif which_model == 'AttentiveRRDBNet':
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
||||
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc']))
|
||||
elif which_model == 'ResGen':
|
||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user