# Forked from mrq/DL-Art-School (Python, 68 lines, 2.7 KiB).
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
|
|
# Utilizes pretrained torchvision modules for feature extraction
|
|
|
|
class VGGFeatureExtractor(nn.Module):
    """Frozen, pretrained VGG19 trunk used as a feature extractor.

    Keeps the first ``feature_layer + 1`` children of torchvision's
    ``vgg19`` (or ``vgg19_bn`` when ``use_bn``) ``features`` stack and
    disables ``requires_grad`` on all of their parameters. Gradients can
    still flow back to the input, but the pretrained weights are never
    updated.

    Args:
        feature_layer: Index of the last module of ``model.features`` to keep
            (inclusive). NOTE: ``vgg19_bn`` inserts BatchNorm layers, so the
            same index selects a different layer there than in ``vgg19``.
        use_bn: Build the trunk from ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: If True, ``forward`` normalizes its input with the
            per-channel ``mean``/``std`` buffers registered here.
        device: Device on which the ``mean``/``std`` buffers are created.
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super().__init__()
        self.use_input_norm = use_input_norm
        # pretrained=True: torchvision may download the weights on first use.
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        if self.use_input_norm:
            # Standard ImageNet channel statistics for inputs in [0, 1].
            # For inputs in [-1, 1] the equivalent values would be
            # mean * 2 - 1 (e.g. 0.485 * 2 - 1) and std * 2 (e.g. 0.229 * 2).
            mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
            # Buffers follow .to()/.cuda() and are saved in the state_dict.
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # Freeze the extractor: no need to backprop into the pretrained weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated-VGG features of ``x``.

        ``x`` is assumed to be in the range [0, 1] when ``use_input_norm`` is
        True. The (1, 3, 1, 1) buffers imply three channels on dim -3;
        presumably ``x`` is (N, 3, H, W) — TODO confirm with callers.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
|
|
|
|
|
|
class WideResnetFeatureExtractor(nn.Module):
    """Frozen, pretrained Wide ResNet-50-2 truncated after ``layer3``.

    ``forward`` runs only the stem (conv1, bn1, relu, maxpool) and
    ``layer1``-``layer3`` of torchvision's ``wide_resnet50_2``. All
    parameters have ``requires_grad`` disabled.

    Args:
        use_input_norm: If True, ``forward`` normalizes its input with the
            per-channel ``mean``/``std`` buffers registered here.
        device: Device on which the ``mean``/``std`` buffers are created.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        print("Using wide resnet extractor.")
        super().__init__()
        self.use_input_norm = use_input_norm
        # pretrained=True: torchvision may download the weights on first use.
        # NOTE(review): layer4/avgpool/fc are never used by forward() but stay
        # registered, so they occupy memory, are moved by .to(), and appear in
        # state_dict under "model.". Dropping them would change checkpoint keys.
        self.model = torchvision.models.wide_resnet50_2(pretrained=True)
        if self.use_input_norm:
            # Standard ImageNet channel statistics for inputs in [0, 1].
            # For inputs in [-1, 1] the equivalent values would be
            # mean * 2 - 1 (e.g. 0.485 * 2 - 1) and std * 2 (e.g. 0.229 * 2).
            mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        # Freeze the extractor: no need to backprop into the pretrained weights.
        # NOTE(review): BatchNorm layers still follow train()/eval() mode;
        # callers presumably keep this module in eval mode — confirm.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the ``layer3`` activations of the truncated network for ``x``.

        ``x`` is assumed to be in the range [0, 1] when ``use_input_norm`` is
        True. It presumably has shape (N, 3, H, W); an unbatched (3, H, W)
        input only works when normalization broadcasts a batch dimension in.
        TODO: confirm the expected shape with callers.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        net = self.model
        # Same order as torchvision's ResNet.forward, stopping before layer4.
        for stage in (net.conv1, net.bn1, net.relu, net.maxpool,
                      net.layer1, net.layer2, net.layer3):
            x = stage(x)
        return x
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test, only when run as a script: importing this module must not
    # build the network (which loads pretrained weights).
    w = WideResnetFeatureExtractor()
    # Pass a proper (N, C, H, W) batch and call the module itself so that
    # nn.Module hooks run, rather than calling .forward() directly.
    features = w(torch.randn(1, 3, 64, 64))
    print(features.shape)