import torch import torch.nn as nn import torchvision # Utilizes pretrained torchvision modules for feature extraction class VGGFeatureExtractor(nn.Module): def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, device=torch.device('cpu')): super(VGGFeatureExtractor, self).__init__() self.use_input_norm = use_input_norm if use_bn: model = torchvision.models.vgg19_bn(pretrained=True) else: model = torchvision.models.vgg19(pretrained=True) if self.use_input_norm: mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] self.register_buffer('mean', mean) self.register_buffer('std', std) self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) # No need to BP to variable for k, v in self.features.named_parameters(): v.requires_grad = False def forward(self, x): # Assume input range is [0, 1] if self.use_input_norm: x = (x - self.mean) / self.std output = self.features(x) return output class WideResnetFeatureExtractor(nn.Module): def __init__(self, use_input_norm=True, device=torch.device('cpu')): print("Using wide resnet extractor.") super(WideResnetFeatureExtractor, self).__init__() self.use_input_norm = use_input_norm self.model = torchvision.models.wide_resnet50_2(pretrained=True) if self.use_input_norm: mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] self.register_buffer('mean', mean) self.register_buffer('std', std) # No need to BP to variable for p in self.model.parameters(): p.requires_grad = False def forward(self, x): # Assume input range is [0, 1] if self.use_input_norm: x = (x - self.mean) / self.std x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) x = self.model.maxpool(x) x = self.model.layer1(x) x = self.model.layer2(x) x = self.model.layer3(x) return x w = WideResnetFeatureExtractor() w.forward(torch.randn(3,64,64))