Some new backbones
This commit is contained in:
parent
9429544a60
commit
419f77ec19
|
@ -410,6 +410,59 @@ class BackboneEncoderNoRef(nn.Module):
|
||||||
return patch
|
return patch
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneSpinenetNoHead(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(BackboneSpinenetNoHead, self).__init__()
|
||||||
|
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
patch = checkpoint(self.patch_spine, x)[0]
|
||||||
|
return patch
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(self, nf, downsample):
|
||||||
|
super(ResBlock, self).__init__()
|
||||||
|
nf_int = nf * 2
|
||||||
|
nf_out = nf * 2 if downsample else nf
|
||||||
|
stride = 2 if downsample else 1
|
||||||
|
self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
|
||||||
|
self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
|
||||||
|
self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
|
||||||
|
if downsample:
|
||||||
|
self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
self.act = nn.SiLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
branch = self.c1(x)
|
||||||
|
branch = self.c2(branch)
|
||||||
|
branch = self.c3(branch)
|
||||||
|
|
||||||
|
if self.downsample:
|
||||||
|
identity = self.downsample(identity)
|
||||||
|
return self.act(identity + branch)
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneResnet(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(BackboneResnet, self).__init__()
|
||||||
|
self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
|
||||||
|
self.sequence = nn.Sequential(
|
||||||
|
ResBlock(64, downsample=False),
|
||||||
|
ResBlock(64, downsample=True),
|
||||||
|
ResBlock(128, downsample=False),
|
||||||
|
ResBlock(128, downsample=True),
|
||||||
|
ResBlock(256, downsample=False),
|
||||||
|
ResBlock(256, downsample=False))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.initial_conv(x)
|
||||||
|
return self.sequence(fea)
|
||||||
|
|
||||||
|
|
||||||
# Note to future self:
|
# Note to future self:
|
||||||
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
|
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
|
||||||
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
|
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
|
||||||
|
|
|
@ -77,6 +77,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
elif which_model == "backbone_encoder_no_ref":
|
elif which_model == "backbone_encoder_no_ref":
|
||||||
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
|
elif which_model == "backbone_resnet":
|
||||||
|
netG = SwitchedGen_arch.BackboneResnet()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user