Some new backbones

This commit is contained in:
James Betker 2020-09-21 12:36:49 -06:00
parent 9429544a60
commit 419f77ec19
2 changed files with 55 additions and 0 deletions

View File

@ -410,6 +410,59 @@ class BackboneEncoderNoRef(nn.Module):
return patch
class BackboneSpinenetNoHead(nn.Module):
def __init__(self):
super(BackboneSpinenetNoHead, self).__init__()
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
def forward(self, x):
patch = checkpoint(self.patch_spine, x)[0]
return patch
class ResBlock(nn.Module):
def __init__(self, nf, downsample):
super(ResBlock, self).__init__()
nf_int = nf * 2
nf_out = nf * 2 if downsample else nf
stride = 2 if downsample else 1
self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
if downsample:
self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
else:
self.downsample = None
self.act = nn.SiLU(inplace=True)
def forward(self, x):
identity = x
branch = self.c1(x)
branch = self.c2(branch)
branch = self.c3(branch)
if self.downsample:
identity = self.downsample(identity)
return self.act(identity + branch)
class BackboneResnet(nn.Module):
def __init__(self):
super(BackboneResnet, self).__init__()
self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
self.sequence = nn.Sequential(
ResBlock(64, downsample=False),
ResBlock(64, downsample=True),
ResBlock(128, downsample=False),
ResBlock(128, downsample=True),
ResBlock(256, downsample=False),
ResBlock(256, downsample=False))
def forward(self, x):
fea = self.initial_conv(x)
return self.sequence(fea)
# Note to future self:
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.

View File

@ -77,6 +77,8 @@ def define_G(opt, net_key='network_G', scale=None):
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_ref":
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_resnet":
netG = SwitchedGen_arch.BackboneResnet()
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))