DL-Art-School/codes/models/archs/feature_arch.py
James Betker 05963157c1 Several things
- Fixes to 'after' and 'before' defs for steps (turns out they weren't working)
- Feature nets take in a list of layers to extract. Not fully implemented yet.
- Fixes bugs with RAGAN
- Allows real input into generator gan to not be detached by param
2020-09-23 11:56:36 -06:00

96 lines
4.1 KiB
Python

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
# Utilizes pretrained torchvision modules for feature extraction
class VGGFeatureExtractor(nn.Module):
    """Frozen, pretrained VGG19 truncated for feature extraction.

    Loads torchvision's ``vgg19`` (or ``vgg19_bn``) with ``pretrained=True``.
    It keeps ``model.features`` up to and including index ``max(feature_layers)``
    and disables ``requires_grad`` on every retained parameter.

    Args:
        feature_layers: Indices into ``vgg19.features``. Currently only the
            largest index is used, to decide where to truncate the network;
            ``forward`` returns the output of that single layer. Defaults to
            ``[34]`` when ``None``. Must be non-empty.
        use_bn: Use ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: Normalize inputs with the fixed per-channel mean/std
            below before running the network.
        device: Device on which the normalization buffers are created.

    Raises:
        ValueError: If ``feature_layers`` is empty.
    """

    def __init__(self, feature_layers=None, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        # ``None`` sentinel instead of a shared mutable ``[34]`` default.
        if feature_layers is None:
            feature_layers = [34]
        # Validate before downloading weights; max() of an empty list would fail later anyway.
        if not feature_layers:
            raise ValueError("feature_layers must contain at least one layer index")
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        if self.use_input_norm:
            # Constants assume input in [0, 1]. For input in [-1, 1] the
            # equivalents are mean' = 2 * mean - 1 and std' = 2 * std.
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.feature_layers = feature_layers
        self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
        # Fixed extractor: no gradients flow into the VGG weights.
        # NOTE(review): with use_bn=True, requires_grad=False does not stop
        # BatchNorm running statistics from updating while in train mode.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x, interpolate_factor=1):
        """Return the activation of the last retained VGG layer for ``x``.

        Args:
            x: 4-D image batch (N, 3, H, W). It is assumed to be in [0, 1]
                when ``use_input_norm`` is set.
            interpolate_factor: If greater than 1, ``x`` is first upscaled
                bicubically by this factor. Values <= 1 leave ``x`` unchanged.
        """
        if interpolate_factor > 1:
            # align_corners=False is PyTorch's default for bicubic; passing it
            # explicitly avoids the "default upsampling behavior" warning.
            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic',
                              align_corners=False)
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
class TrainableVGGFeatureExtractor(nn.Module):
    """Randomly initialized, trainable VGG19 truncated at ``feature_layer``.

    Builds torchvision's ``vgg19`` (or ``vgg19_bn``) with ``pretrained=False``
    and keeps ``model.features`` up to and including index ``feature_layer``.
    Its parameters are left trainable.

    Args:
        feature_layer: Last index of ``vgg19.features`` to keep.
        use_bn: Use ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: Normalize inputs with the fixed per-channel mean/std
            below before running the network.
        device: Device on which the normalization buffers are created.
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(TrainableVGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=False)
        else:
            model = torchvision.models.vgg19(pretrained=False)
        if self.use_input_norm:
            # Constants assume input in [0, 1]. For input in [-1, 1] the
            # equivalents are mean' = 2 * mean - 1 and std' = 2 * std.
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])

    def forward(self, x, interpolate_factor=1):
        """Return the activation of layer ``feature_layer`` for ``x``.

        Args:
            x: 4-D image batch (N, 3, H, W). It is assumed to be in [0, 1]
                when ``use_input_norm`` is set.
            interpolate_factor: If greater than 1, ``x`` is first upscaled
                bicubically by this factor. Values <= 1 leave ``x`` unchanged.
        """
        if interpolate_factor > 1:
            # align_corners=False is PyTorch's default for bicubic; passing it
            # explicitly avoids the "default upsampling behavior" warning.
            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic',
                              align_corners=False)
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
class WideResnetFeatureExtractor(nn.Module):
    """Frozen, pretrained Wide-ResNet-50-2 used up to the end of ``layer3``.

    Loads torchvision's ``wide_resnet50_2`` with ``pretrained=True`` and
    disables ``requires_grad`` on all of its parameters. ``forward`` runs the
    stem (conv1, bn1, relu, maxpool) and ``layer1``..``layer3``.
    ``layer4``, ``avgpool`` and ``fc`` are loaded but never used.

    Args:
        use_input_norm: Normalize inputs with the fixed per-channel mean/std
            below before running the network.
        device: Device on which the normalization buffers are created.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        print("Using wide resnet extractor.")
        super(WideResnetFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        self.model = torchvision.models.wide_resnet50_2(pretrained=True)
        if self.use_input_norm:
            # Constants assume input in [0, 1]. For input in [-1, 1] the
            # equivalents are mean' = 2 * mean - 1 and std' = 2 * std.
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        # Fixed extractor: no gradients flow into the network weights.
        # NOTE(review): requires_grad=False does not stop BatchNorm running
        # statistics from updating while the module is in train mode.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x, interpolate_factor=1):
        """Return the ``layer3`` activation of the Wide-ResNet for ``x``.

        Args:
            x: 4-D image batch (N, 3, H, W). It is assumed to be in [0, 1]
                when ``use_input_norm`` is set.
            interpolate_factor: If greater than 1, ``x`` is first upscaled
                bicubically by this factor. Values <= 1 (the default) leave ``x``
                unchanged. This matches the VGG extractors in this module, so
                callers can pass it to any extractor.
        """
        if interpolate_factor > 1:
            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic',
                              align_corners=False)
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        return x