forked from mrq/DL-Art-School
CIFAR stuff
- Extract coarse labels for the CIFAR dataset - Add simple resnet that branches lower layers based on coarse labels - Some other cleanup
This commit is contained in:
parent
80d4404367
commit
fb405d9ef1
176
codes/data/cifar.py
Normal file
176
codes/data/cifar.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
# A copy of the cifar dataset from torch which also returns coarse labels.
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import pickle
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
class CIFAR10(VisionDataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
Args:
|
||||
root (string): Root directory of dataset where directory
|
||||
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
|
||||
train (bool, optional): If True, creates dataset from training set, otherwise
|
||||
creates from test set.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
download (bool, optional): If true, downloads the dataset from the internet and
|
||||
puts it in root directory. If dataset is already downloaded, it is not
|
||||
downloaded again.
|
||||
|
||||
"""
|
||||
base_folder = 'cifar-10-batches-py'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||
filename = "cifar-10-python.tar.gz"
|
||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||
train_list = [
|
||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
|
||||
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
|
||||
['data_batch_4', '634d18415352ddfa80567beed471001a'],
|
||||
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test_batch', '40351d587109b95175f43aff81a1287e'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'batches.meta',
|
||||
'key': 'label_names',
|
||||
'md5': '5ff9c542aee3614f3951f8cda6e48888',
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
train: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
|
||||
super(CIFAR10, self).__init__(root, transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
self.train = train # training set or test set
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError('Dataset not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
|
||||
if self.train:
|
||||
downloaded_list = self.train_list
|
||||
else:
|
||||
downloaded_list = self.test_list
|
||||
|
||||
self.data: Any = []
|
||||
self.targets = []
|
||||
self.coarse_targets = []
|
||||
|
||||
# now load the picked numpy arrays
|
||||
for file_name, checksum in downloaded_list:
|
||||
file_path = os.path.join(self.root, self.base_folder, file_name)
|
||||
with open(file_path, 'rb') as f:
|
||||
entry = pickle.load(f, encoding='latin1')
|
||||
self.data.append(entry['data'])
|
||||
if 'labels' in entry:
|
||||
self.targets.extend(entry['labels'])
|
||||
else:
|
||||
self.targets.extend(entry['fine_labels'])
|
||||
self.coarse_targets.extend(entry['coarse_labels'])
|
||||
|
||||
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
|
||||
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
|
||||
|
||||
self._load_meta()
|
||||
|
||||
def _load_meta(self) -> None:
|
||||
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
|
||||
if not check_integrity(path, self.meta['md5']):
|
||||
raise RuntimeError('Dataset metadata file not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
with open(path, 'rb') as infile:
|
||||
data = pickle.load(infile, encoding='latin1')
|
||||
self.classes = data[self.meta['key']]
|
||||
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target class.
|
||||
"""
|
||||
img, target = self.data[index], self.targets[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
if len(self.coarse_targets) > 0:
|
||||
return img, target, self.coarse_targets[index]
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def _check_integrity(self) -> bool:
|
||||
root = self.root
|
||||
for fentry in (self.train_list + self.test_list):
|
||||
filename, md5 = fentry[0], fentry[1]
|
||||
fpath = os.path.join(root, self.base_folder, filename)
|
||||
if not check_integrity(fpath, md5):
|
||||
return False
|
||||
return True
|
||||
|
||||
def download(self) -> None:
|
||||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "Split: {}".format("Train" if self.train is True else "Test")
|
||||
|
||||
|
||||
class CIFAR100(CIFAR10):
|
||||
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
This is a subclass of the `CIFAR10` Dataset.
|
||||
"""
|
||||
base_folder = 'cifar-100-python'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
||||
filename = "cifar-100-python.tar.gz"
|
||||
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
||||
train_list = [
|
||||
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'meta',
|
||||
'key': 'fine_label_names',
|
||||
'md5': '7973b15100ade9c7d40fb424638fde48',
|
||||
}
|
|
@ -4,6 +4,7 @@ import torchvision.transforms as T
|
|||
from torchvision import datasets
|
||||
|
||||
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
||||
from data.cifar import CIFAR100, CIFAR10
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -12,8 +13,8 @@ class TorchDataset(Dataset):
|
|||
DATASET_MAP = {
|
||||
"mnist": datasets.MNIST,
|
||||
"fmnist": datasets.FashionMNIST,
|
||||
"cifar10": datasets.CIFAR10,
|
||||
"cifar100": datasets.CIFAR100,
|
||||
"cifar10": CIFAR10,
|
||||
"cifar100": CIFAR100,
|
||||
"imagenet": datasets.ImageNet,
|
||||
"imagefolder": datasets.ImageFolder
|
||||
}
|
||||
|
@ -39,8 +40,15 @@ class TorchDataset(Dataset):
|
|||
self.offset = opt_get(opt, ['offset'], 0)
|
||||
|
||||
def __getitem__(self, item):
|
||||
underlying_item, lbl = self.dataset[item+self.offset]
|
||||
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl,
|
||||
item = self.dataset[item+self.offset]
|
||||
if len(item) == 2:
|
||||
underlying_item, lbl = item
|
||||
coarselbl = None
|
||||
elif len(item) == 3:
|
||||
underlying_item, lbl, coarselbl = item
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'coarse_labels': coarselbl,
|
||||
'LQ_path': str(item), 'GT_path': str(item)}
|
||||
|
||||
def __len__(self):
|
||||
|
|
0
codes/models/classifiers/__init__.py
Normal file
0
codes/models/classifiers/__init__.py
Normal file
175
codes/models/classifiers/cifar_resnet_branched.py
Normal file
175
codes/models/classifiers/cifar_resnet_branched.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
"""resnet in pytorch
|
||||
|
||||
|
||||
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||
|
||||
Deep Residual Learning for Image Recognition
|
||||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic Block for resnet 18 and resnet 34
|
||||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
|
||||
class ResNetTail(nn.Module):
|
||||
def __init__(self, block, num_block, num_classes=100):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = 128
|
||||
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.conv4_x(x)
|
||||
output = self.conv5_x(output)
|
||||
output = self.avg_pool(output)
|
||||
output = output.view(output.size(0), -1)
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, num_classes=100, num_tails=20):
|
||||
super().__init__()
|
||||
self.in_channels = 64
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||
self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, coarse_label):
|
||||
output = self.conv1(x)
|
||||
output = self.conv2_x(output)
|
||||
output = self.conv3_x(output)
|
||||
bs = output.shape[0]
|
||||
tailouts = []
|
||||
for t in self.tails:
|
||||
tailouts.append(t(output))
|
||||
tailouts = torch.stack(tailouts, dim=0)
|
||||
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
|
||||
|
||||
@register_model
|
||||
def register_cifar_resnet18(opt_net, opt):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = ResNet(BasicBlock, [2,2,2,2])
|
||||
print(model(torch.randn(2,3,32,32), torch.LongTensor([4,19])).shape)
|
||||
|
|
@ -4,22 +4,16 @@ import shutil
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import ToTensor, Resize
|
||||
from torchvision.transforms import ToTensor
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
from models.classifiers.resnet_with_checkpointing import resnet50
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
# and the distance is computed across the channel dimension.
|
||||
from utils import util
|
||||
from utils.kmeans import kmeans, kmeans_predict
|
||||
from utils.options import dict_to_nonedict
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -7,20 +6,14 @@ import torch.nn.functional as F
|
|||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import ToTensor, Resize, Normalize
|
||||
from torchvision.transforms import ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
from models.segformer.segformer import Segformer
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
# and the distance is computed across the channel dimension.
|
||||
from utils import util
|
||||
from utils.kmeans import kmeans, kmeans_predict
|
||||
from utils.options import dict_to_nonedict
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import shutil
|
||||
from random import shuffle
|
||||
|
||||
import matplotlib.cm as cm
|
||||
|
@ -7,26 +6,15 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.models.resnet import Bottleneck
|
||||
from torchvision.transforms import ToTensor, Resize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
# and the distance is computed across the channel dimension.
|
||||
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
|
||||
from utils import util
|
||||
from utils.kmeans import kmeans, kmeans_predict
|
||||
from utils.options import dict_to_nonedict
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user