Some new backbones
This commit is contained in:
parent
9429544a60
commit
419f77ec19
|
@ -410,6 +410,59 @@ class BackboneEncoderNoRef(nn.Module):
|
|||
return patch
|
||||
|
||||
|
||||
class BackboneSpinenetNoHead(nn.Module):
|
||||
def __init__(self):
|
||||
super(BackboneSpinenetNoHead, self).__init__()
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||
|
||||
def forward(self, x):
|
||||
patch = checkpoint(self.patch_spine, x)[0]
|
||||
return patch
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, nf, downsample):
|
||||
super(ResBlock, self).__init__()
|
||||
nf_int = nf * 2
|
||||
nf_out = nf * 2 if downsample else nf
|
||||
stride = 2 if downsample else 1
|
||||
self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
|
||||
if downsample:
|
||||
self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.act = nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
branch = self.c1(x)
|
||||
branch = self.c2(branch)
|
||||
branch = self.c3(branch)
|
||||
|
||||
if self.downsample:
|
||||
identity = self.downsample(identity)
|
||||
return self.act(identity + branch)
|
||||
|
||||
|
||||
class BackboneResnet(nn.Module):
|
||||
def __init__(self):
|
||||
super(BackboneResnet, self).__init__()
|
||||
self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
|
||||
self.sequence = nn.Sequential(
|
||||
ResBlock(64, downsample=False),
|
||||
ResBlock(64, downsample=True),
|
||||
ResBlock(128, downsample=False),
|
||||
ResBlock(128, downsample=True),
|
||||
ResBlock(256, downsample=False),
|
||||
ResBlock(256, downsample=False))
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.initial_conv(x)
|
||||
return self.sequence(fea)
|
||||
|
||||
|
||||
# Note to future self:
|
||||
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
|
||||
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
|
||||
|
|
|
@ -77,6 +77,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_encoder_no_ref":
|
||||
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_resnet":
|
||||
netG = SwitchedGen_arch.BackboneResnet()
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user