From 57682ebee366ee4e78bd1a1e85d6b3505304c1f3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 28 May 2020 20:26:30 -0600
Subject: [PATCH] Separate feature extractors out, add resnet feature extractor

---
Hand-edited after export: dropped the import-time smoke test from
feature_arch.py, made define_F reject unknown which_model_F values, and
added docstrings. Blob ids on the index lines predate these edits.

 codes/models/archs/discriminator_vgg_arch.py | 28 --------
 codes/models/archs/feature_arch.py           | 74 ++++++++++++++++++++
 codes/models/networks.py                     | 30 +++++---
 3 files changed, 96 insertions(+), 36 deletions(-)
 create mode 100644 codes/models/archs/feature_arch.py

diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 29dda1dc..f86a94fd 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -62,31 +62,3 @@ class Discriminator_VGG_128(nn.Module):
         out = self.linear2(fea)
         return out
 
-
-class VGGFeatureExtractor(nn.Module):
-    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
-                 device=torch.device('cpu')):
-        super(VGGFeatureExtractor, self).__init__()
-        self.use_input_norm = use_input_norm
-        if use_bn:
-            model = torchvision.models.vgg19_bn(pretrained=True)
-        else:
-            model = torchvision.models.vgg19(pretrained=True)
-        if self.use_input_norm:
-            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
-            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
-            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
-            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
-            self.register_buffer('mean', mean)
-            self.register_buffer('std', std)
-        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
-        # No need to BP to variable
-        for k, v in self.features.named_parameters():
-            v.requires_grad = False
-
-    def forward(self, x):
-        # Assume input range is [0, 1]
-        if self.use_input_norm:
-            x = (x - self.mean) / self.std
-        output = self.features(x)
-        return output
diff --git a/codes/models/archs/feature_arch.py b/codes/models/archs/feature_arch.py
new file mode 100644
index 00000000..3cafcc0d
--- /dev/null
+++ b/codes/models/archs/feature_arch.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+# Utilizes pretrained torchvision modules for feature extraction
+
+class VGGFeatureExtractor(nn.Module):
+    """Pretrained torchvision VGG19 truncated after `feature_layer`, weights frozen.
+
+    With `use_input_norm`, inputs (assumed to be in [0, 1]) are normalized with
+    the per-channel mean/std buffers before the convolutional features run.
+    """
+    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+                 device=torch.device('cpu')):
+        super(VGGFeatureExtractor, self).__init__()
+        self.use_input_norm = use_input_norm
+        if use_bn:
+            model = torchvision.models.vgg19_bn(pretrained=True)
+        else:
+            model = torchvision.models.vgg19(pretrained=True)
+        if self.use_input_norm:
+            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+            self.register_buffer('mean', mean)
+            self.register_buffer('std', std)
+        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+        # No need to BP to variable
+        for k, v in self.features.named_parameters():
+            v.requires_grad = False
+
+    def forward(self, x):
+        # Assume input range is [0, 1]
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        output = self.features(x)
+        return output
+
+
+class WideResnetFeatureExtractor(nn.Module):
+    """Pretrained torchvision wide_resnet50_2 evaluated up to `layer3`, weights frozen.
+
+    With `use_input_norm`, inputs (assumed to be in [0, 1]) are normalized with
+    the per-channel mean/std buffers first.
+    """
+    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
+        print("Using wide resnet extractor.")
+        super(WideResnetFeatureExtractor, self).__init__()
+        self.use_input_norm = use_input_norm
+        self.model = torchvision.models.wide_resnet50_2(pretrained=True)
+        if self.use_input_norm:
+            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+            self.register_buffer('mean', mean)
+            self.register_buffer('std', std)
+        # No need to BP to variable
+        for p in self.model.parameters():
+            p.requires_grad = False
+
+    def forward(self, x):
+        # Assume input range is [0, 1]
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        x = self.model.conv1(x)
+        x = self.model.bn1(x)
+        x = self.model.relu(x)
+        x = self.model.maxpool(x)
+        x = self.model.layer1(x)
+        x = self.model.layer2(x)
+        x = self.model.layer3(x)
+        return x
diff --git a/codes/models/networks.py b/codes/models/networks.py
index b5d83a42..e6ec0f4a 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -8,7 +8,7 @@ import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.ResGen_arch as ResGen_arch
 import models.archs.biggan_gen_arch as biggan_arch
-import math
+import models.archs.feature_arch as feature_arch
 
 # Generator
 def define_G(opt, net_key='network_G'):
@@ -83,12 +83,26 @@ def define_D(opt):
 def define_F(opt, use_bn=False):
+    """Create the feature extractor selected by opt['train']['which_model_F'].
+
+    'vgg' (also used when the key is absent) gives a VGGFeatureExtractor and
+    'wide_resnet' a WideResnetFeatureExtractor; any other value raises
+    NotImplementedError. The network is placed in eval mode before returning.
+    """
     gpu_ids = opt['gpu_ids']
     device = torch.device('cuda' if gpu_ids else 'cpu')
-    # PyTorch pretrained VGG19-54, before ReLU.
-    if use_bn:
-        feature_layer = 49
-    else:
-        feature_layer = 34
-    netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
-                                          use_input_norm=True, device=device)
+    if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
+        # PyTorch pretrained VGG19-54, before ReLU.
+        if use_bn:
+            feature_layer = 49
+        else:
+            feature_layer = 34
+        netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                                use_input_norm=True, device=device)
+    elif opt['train']['which_model_F'] == 'wide_resnet':
+        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
+    else:
+        # Without this branch netF stays unbound and netF.eval() raises UnboundLocalError.
+        raise NotImplementedError(
+            'Feature extractor [{}] not recognized'.format(opt['train']['which_model_F']))
+
     netF.eval()  # No need to train
     return netF