Separate feature extractors out, add resnet feature extractor
This commit is contained in:
parent
156cee240a
commit
57682ebee3
|
@ -62,31 +62,3 @@ class Discriminator_VGG_128(nn.Module):
|
|||
out = self.linear2(fea)
|
||||
return out
|
||||
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
|
||||
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
|
||||
device=torch.device('cpu')):
|
||||
super(VGGFeatureExtractor, self).__init__()
|
||||
self.use_input_norm = use_input_norm
|
||||
if use_bn:
|
||||
model = torchvision.models.vgg19_bn(pretrained=True)
|
||||
else:
|
||||
model = torchvision.models.vgg19(pretrained=True)
|
||||
if self.use_input_norm:
|
||||
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
||||
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
|
||||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
||||
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
|
||||
self.register_buffer('mean', mean)
|
||||
self.register_buffer('std', std)
|
||||
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
|
||||
# No need to BP to variable
|
||||
for k, v in self.features.named_parameters():
|
||||
v.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
# Assume input range is [0, 1]
|
||||
if self.use_input_norm:
|
||||
x = (x - self.mean) / self.std
|
||||
output = self.features(x)
|
||||
return output
|
||||
|
|
68
codes/models/archs/feature_arch.py
Normal file
68
codes/models/archs/feature_arch.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
# Utilizes pretrained torchvision modules for feature extraction
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
|
||||
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
|
||||
device=torch.device('cpu')):
|
||||
super(VGGFeatureExtractor, self).__init__()
|
||||
self.use_input_norm = use_input_norm
|
||||
if use_bn:
|
||||
model = torchvision.models.vgg19_bn(pretrained=True)
|
||||
else:
|
||||
model = torchvision.models.vgg19(pretrained=True)
|
||||
if self.use_input_norm:
|
||||
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
||||
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
|
||||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
||||
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
|
||||
self.register_buffer('mean', mean)
|
||||
self.register_buffer('std', std)
|
||||
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
|
||||
# No need to BP to variable
|
||||
for k, v in self.features.named_parameters():
|
||||
v.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
# Assume input range is [0, 1]
|
||||
if self.use_input_norm:
|
||||
x = (x - self.mean) / self.std
|
||||
output = self.features(x)
|
||||
return output
|
||||
|
||||
|
||||
class WideResnetFeatureExtractor(nn.Module):
|
||||
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
|
||||
print("Using wide resnet extractor.")
|
||||
super(WideResnetFeatureExtractor, self).__init__()
|
||||
self.use_input_norm = use_input_norm
|
||||
self.model = torchvision.models.wide_resnet50_2(pretrained=True)
|
||||
if self.use_input_norm:
|
||||
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
||||
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
|
||||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
||||
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
|
||||
self.register_buffer('mean', mean)
|
||||
self.register_buffer('std', std)
|
||||
# No need to BP to variable
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
# Assume input range is [0, 1]
|
||||
if self.use_input_norm:
|
||||
x = (x - self.mean) / self.std
|
||||
x = self.model.conv1(x)
|
||||
x = self.model.bn1(x)
|
||||
x = self.model.relu(x)
|
||||
x = self.model.maxpool(x)
|
||||
x = self.model.layer1(x)
|
||||
x = self.model.layer2(x)
|
||||
x = self.model.layer3(x)
|
||||
return x
|
||||
|
||||
|
||||
w = WideResnetFeatureExtractor()
|
||||
w.forward(torch.randn(3,64,64))
|
|
@ -8,7 +8,7 @@ import models.archs.RRDBNet_arch as RRDBNet_arch
|
|||
import models.archs.HighToLowResNet as HighToLowResNet
|
||||
import models.archs.ResGen_arch as ResGen_arch
|
||||
import models.archs.biggan_gen_arch as biggan_arch
|
||||
import math
|
||||
import models.archs.feature_arch as feature_arch
|
||||
|
||||
# Generator
|
||||
def define_G(opt, net_key='network_G'):
|
||||
|
@ -83,12 +83,16 @@ def define_D(opt):
|
|||
def define_F(opt, use_bn=False):
|
||||
gpu_ids = opt['gpu_ids']
|
||||
device = torch.device('cuda' if gpu_ids else 'cpu')
|
||||
# PyTorch pretrained VGG19-54, before ReLU.
|
||||
if use_bn:
|
||||
feature_layer = 49
|
||||
else:
|
||||
feature_layer = 34
|
||||
netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
||||
use_input_norm=True, device=device)
|
||||
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
|
||||
# PyTorch pretrained VGG19-54, before ReLU.
|
||||
if use_bn:
|
||||
feature_layer = 49
|
||||
else:
|
||||
feature_layer = 34
|
||||
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
||||
use_input_norm=True, device=device)
|
||||
elif opt['train']['which_model_F'] == 'wide_resnet':
|
||||
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
|
||||
|
||||
netF.eval() # No need to train
|
||||
return netF
|
||||
|
|
Loading…
Reference in New Issue
Block a user