CIFAR stuff

- Extract coarse labels for the CIFAR dataset
- Add simple resnet that branches lower layers based on coarse labels
- Some other cleanup
This commit is contained in:
James Betker 2021-06-05 14:16:02 -06:00
parent 80d4404367
commit fb405d9ef1
10 changed files with 366 additions and 32 deletions

176
codes/data/cifar.py Normal file
View File

@ -0,0 +1,176 @@
# A copy of the cifar dataset from torch which also returns coarse labels.
from PIL import Image
import os
import os.path
import numpy as np
import pickle
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data: Any = []
self.targets = []
self.coarse_targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.coarse_targets.extend(entry['coarse_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if len(self.coarse_targets) > 0:
return img, target, self.coarse_targets[index]
return img, target
def __len__(self) -> int:
return len(self.data)
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This is a subclass of the `CIFAR10` Dataset.
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}

View File

@ -4,6 +4,7 @@ import torchvision.transforms as T
from torchvision import datasets from torchvision import datasets
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer. # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
from data.cifar import CIFAR100, CIFAR10
from utils.util import opt_get from utils.util import opt_get
@ -12,8 +13,8 @@ class TorchDataset(Dataset):
DATASET_MAP = { DATASET_MAP = {
"mnist": datasets.MNIST, "mnist": datasets.MNIST,
"fmnist": datasets.FashionMNIST, "fmnist": datasets.FashionMNIST,
"cifar10": datasets.CIFAR10, "cifar10": CIFAR10,
"cifar100": datasets.CIFAR100, "cifar100": CIFAR100,
"imagenet": datasets.ImageNet, "imagenet": datasets.ImageNet,
"imagefolder": datasets.ImageFolder "imagefolder": datasets.ImageFolder
} }
@ -39,8 +40,15 @@ class TorchDataset(Dataset):
self.offset = opt_get(opt, ['offset'], 0) self.offset = opt_get(opt, ['offset'], 0)
def __getitem__(self, item): def __getitem__(self, item):
underlying_item, lbl = self.dataset[item+self.offset] item = self.dataset[item+self.offset]
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, if len(item) == 2:
underlying_item, lbl = item
coarselbl = None
elif len(item) == 3:
underlying_item, lbl, coarselbl = item
else:
raise NotImplementedError
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'coarse_labels': coarselbl,
'LQ_path': str(item), 'GT_path': str(item)} 'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self): def __len__(self):

View File

View File

@ -0,0 +1,175 @@
"""resnet in pytorch
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
Deep Residual Learning for Image Recognition
https://arxiv.org/abs/1512.03385v1
"""
import torch
import torch.nn as nn
from trainer.networks import register_model
class BasicBlock(nn.Module):
"""Basic Block for resnet 18 and resnet 34
"""
#BasicBlock and BottleNeck block
#have different output size
#we use class attribute expansion
#to distinct
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
#residual function
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
#shortcut
self.shortcut = nn.Sequential()
#the shortcut output dimension is not the same with residual function
#use 1*1 convolution to match the dimension
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BottleNeck(nn.Module):
"""Residual block for resnet over 50 layers
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNetTail(nn.Module):
def __init__(self, block, num_block, num_classes=100):
super().__init__()
self.in_channels = 128
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
output = self.conv4_x(x)
output = self.conv5_x(output)
output = self.avg_pool(output)
output = output.view(output.size(0), -1)
output = self.fc(output)
return output
class ResNet(nn.Module):
def __init__(self, block, num_block, num_classes=100, num_tails=20):
super().__init__()
self.in_channels = 64
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True))
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x, coarse_label):
output = self.conv1(x)
output = self.conv2_x(output)
output = self.conv3_x(output)
bs = output.shape[0]
tailouts = []
for t in self.tails:
tailouts.append(t(output))
tailouts = torch.stack(tailouts, dim=0)
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
@register_model
def register_cifar_resnet18(opt_net, opt):
""" return a ResNet 18 object
"""
return ResNet(BasicBlock, [2, 2, 2, 2])
def resnet34():
""" return a ResNet 34 object
"""
return ResNet(BasicBlock, [3, 4, 6, 3])
def resnet50():
""" return a ResNet 50 object
"""
return ResNet(BottleNeck, [3, 4, 6, 3])
def resnet101():
""" return a ResNet 101 object
"""
return ResNet(BottleNeck, [3, 4, 23, 3])
def resnet152():
""" return a ResNet 152 object
"""
return ResNet(BottleNeck, [3, 8, 36, 3])
if __name__ == '__main__':
model = ResNet(BasicBlock, [2,2,2,2])
print(model(torch.randn(2,3,32,32), torch.LongTensor([4,19])).shape)

View File

@ -4,22 +4,16 @@ import shutil
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision
from PIL import Image from PIL import Image
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize from torchvision.transforms import ToTensor
from tqdm import tqdm from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset from data.image_folder_dataset import ImageFolderDataset
from models.resnet_with_checkpointing import resnet50 from models.classifiers.resnet_with_checkpointing import resnet50
from models.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension. # and the distance is computed across the channel dimension.
from utils import util
from utils.kmeans import kmeans, kmeans_predict from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict from utils.options import dict_to_nonedict

View File

@ -1,5 +1,4 @@
import os import os
import shutil
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -7,20 +6,14 @@ import torch.nn.functional as F
import torchvision import torchvision
from PIL import Image from PIL import Image
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize, Normalize from torchvision.transforms import ToTensor, Normalize
from tqdm import tqdm from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset from data.image_folder_dataset import ImageFolderDataset
from models.resnet_with_checkpointing import resnet50
from models.segformer.segformer import Segformer from models.segformer.segformer import Segformer
from models.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension. # and the distance is computed across the channel dimension.
from utils import util
from utils.kmeans import kmeans, kmeans_predict from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict from utils.options import dict_to_nonedict

View File

@ -1,5 +1,4 @@
import os import os
import shutil
from random import shuffle from random import shuffle
import matplotlib.cm as cm import matplotlib.cm as cm
@ -7,26 +6,15 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision import torchvision
from PIL import Image
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.models.resnet import Bottleneck from torchvision.models.resnet import Bottleneck
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset from data.image_folder_dataset import ImageFolderDataset
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3 from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
from models.resnet_with_checkpointing import resnet50
from models.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension. # and the distance is computed across the channel dimension.
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
from utils import util
from utils.kmeans import kmeans, kmeans_predict from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict from utils.options import dict_to_nonedict