diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 0e9e4460..1abe2b5f 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -410,6 +410,59 @@ class BackboneEncoderNoRef(nn.Module):
         return patch
 
 
+class BackboneSpinenetNoHead(nn.Module):
+    def __init__(self):
+        super(BackboneSpinenetNoHead, self).__init__()
+        self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
+
+    def forward(self, x):
+        patch = checkpoint(self.patch_spine, x)[0]
+        return patch
+
+
+class ResBlock(nn.Module):
+    def __init__(self, nf, downsample):
+        super(ResBlock, self).__init__()
+        nf_int = nf * 2
+        nf_out = nf * 2 if downsample else nf
+        stride = 2 if downsample else 1
+        self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
+        self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
+        self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
+        if downsample:
+            self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
+        else:
+            self.downsample = None
+        self.act = nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        identity = x
+        branch = self.c1(x)
+        branch = self.c2(branch)
+        branch = self.c3(branch)
+
+        if self.downsample:
+            identity = self.downsample(identity)
+        return self.act(identity + branch)
+
+
+class BackboneResnet(nn.Module):
+    def __init__(self):
+        super(BackboneResnet, self).__init__()
+        self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
+        self.sequence = nn.Sequential(
+            ResBlock(64, downsample=False),
+            ResBlock(64, downsample=True),
+            ResBlock(128, downsample=False),
+            ResBlock(128, downsample=True),
+            ResBlock(256, downsample=False),
+            ResBlock(256, downsample=False))
+
+    def forward(self, x):
+        fea = self.initial_conv(x)
+        return self.sequence(fea)
+
+
 # Note to future self:
 # Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
 # their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 672d3c4e..e74cd96f 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -77,6 +77,8 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
     elif which_model == "backbone_encoder_no_ref":
         netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
+    elif which_model == "backbone_resnet":
+        netG = SwitchedGen_arch.BackboneResnet()
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
 
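
Note on the BackboneSpinenetNoHead forward: it wraps the SpineNet trunk in torch.utils.checkpoint, so the trunk's activations are recomputed during backward rather than stored. A minimal sketch of that pattern follows; the Conv2d stand-in is hypothetical, and only the checkpoint call mirrors what the commit does:

import torch
from torch.utils.checkpoint import checkpoint

# checkpoint() runs the wrapped module without saving its intermediate
# activations, then recomputes them on the backward pass, trading compute
# for memory. The Conv2d here is a stand-in for the SpineNet trunk.
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16, requires_grad=True)
y = checkpoint(conv, x)
y.sum().backward()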
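
A quick shape check of the new residual blocks, as a hypothetical smoke test (not part of the commit). It assumes ConvGnSilu applies 'same'-style padding, so only the stride-2 convolutions change resolution, and that the repo's codes/ directory is on PYTHONPATH:

import torch
from models.archs.SwitchedResidualGenerator_arch import ResBlock, BackboneResnet

# A downsampling ResBlock doubles channels and halves resolution: the stride-2
# middle conv shrinks the branch while the 1x1 stride-2 conv shrinks the identity.
blk = ResBlock(64, downsample=True)
print(blk(torch.randn(1, 64, 32, 32)).shape)  # expected: torch.Size([1, 128, 16, 16])

# BackboneResnet stacks two such stride-2 stages, so a 3-channel 64x64 input
# should come out as a 256-channel feature map at 1/4 resolution.
net = BackboneResnet()
print(net(torch.randn(1, 3, 64, 64)).shape)  # expected: torch.Size([1, 256, 16, 16])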
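
On the networks.py side, the new branch is reached through the existing which_model_G dispatch. Below is an illustrative sketch of the minimal options dict that would route define_G to the new backbone; the 'scale' key and overall dict layout are assumptions based on how the other branches read opt, not something this commit adds:

from models.networks import define_G

# Hypothetical minimal options dict: 'which_model_G' is the key the new elif
# matches on. BackboneResnet takes no constructor arguments, so nothing else
# is needed under network_G.
opt = {'scale': 1, 'network_G': {'which_model_G': 'backbone_resnet'}}
netG = define_G(opt)  # -> SwitchedGen_arch.BackboneResnet()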