From fb405d9ef1783d1f19fde9fa1213aa128ee285ab Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 5 Jun 2021 14:16:02 -0600
Subject: [PATCH] CIFAR stuff

- Extract coarse labels for the CIFAR dataset
- Add simple resnet that branches lower layers based on coarse labels
- Some other cleanup
---
 codes/data/cifar.py                           | 176 ++++++++++++++++++
 codes/data/torch_dataset.py                   |  16 +-
 codes/models/classifiers/__init__.py          |   0
 .../models/{ => classifiers}/cifar_resnet.py  |   0
 .../classifiers/cifar_resnet_branched.py      | 175 +++++++++++++++++
 .../resnet_with_checkpointing.py              |   0
 .../{ => classifiers}/weighted_conv_resnet.py |   0
 codes/scripts/byol/byol_resnet_playground.py  |  10 +-
 .../scripts/byol/byol_segformer_playground.py |   9 +-
 codes/scripts/byol/byol_uresnet_playground.py |  12 --
 10 files changed, 366 insertions(+), 32 deletions(-)
 create mode 100644 codes/data/cifar.py
 create mode 100644 codes/models/classifiers/__init__.py
 rename codes/models/{ => classifiers}/cifar_resnet.py (100%)
 create mode 100644 codes/models/classifiers/cifar_resnet_branched.py
 rename codes/models/{ => classifiers}/resnet_with_checkpointing.py (100%)
 rename codes/models/{ => classifiers}/weighted_conv_resnet.py (100%)

diff --git a/codes/data/cifar.py b/codes/data/cifar.py
new file mode 100644
index 00000000..c8ff7c9c
--- /dev/null
+++ b/codes/data/cifar.py
@@ -0,0 +1,176 @@
+# A copy of the cifar dataset from torch which also returns coarse labels.
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pickle
+from typing import Any, Callable, Optional, Tuple
+
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.utils import check_integrity, download_and_extract_archive
+
+
+class CIFAR10(VisionDataset):
+    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+        train (bool, optional): If True, creates dataset from training set, otherwise
+            creates from test set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+    base_folder = 'cifar-10-batches-py'
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
+    train_list = [
+        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
+        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
+        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
+        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
+        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
+    ]
+
+    test_list = [
+        ['test_batch', '40351d587109b95175f43aff81a1287e'],
+    ]
+    meta = {
+        'filename': 'batches.meta',
+        'key': 'label_names',
+        'md5': '5ff9c542aee3614f3951f8cda6e48888',
+    }
+
+    def __init__(
+            self,
+            root: str,
+            train: bool = True,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+            download: bool = False,
+    ) -> None:
+
+        super(CIFAR10, self).__init__(root, transform=transform,
+                                      target_transform=target_transform)
+
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        if self.train:
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data: Any = []
+        self.targets = []
+        self.coarse_targets = []
+
+        # now load the pickled numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
+                entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
+                if 'labels' in entry:
+                    self.targets.extend(entry['labels'])
+                else:
+                    self.targets.extend(entry['fine_labels'])
+                    self.coarse_targets.extend(entry['coarse_labels'])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
+
+        self._load_meta()
+
+    def _load_meta(self) -> None:
+        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
+        if not check_integrity(path, self.meta['md5']):
+            raise RuntimeError('Dataset metadata file not found or corrupted.' +
+                               ' You can use download=True to download it')
+        with open(path, 'rb') as infile:
+            data = pickle.load(infile, encoding='latin1')
+            self.classes = data[self.meta['key']]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is the index of the target class, plus the coarse target when coarse labels are present.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        if len(self.coarse_targets) > 0:
+            return img, target, self.coarse_targets[index]
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        for fentry in (self.train_list + self.test_list):
+            filename, md5 = fentry[0], fentry[1]
+            fpath = os.path.join(root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+    def extra_repr(self) -> str:
+        return "Split: {}".format("Train" if self.train is True else "Test")
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
+    """
+    base_folder = 'cifar-100-python'
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+    train_list = [
+        ['train', '16019d7e3df5f24257cddd939b257f8d'],
+    ]
+
+    test_list = [
+        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+    ]
+    meta = {
+        'filename': 'meta',
+        'key': 'fine_label_names',
+        'md5': '7973b15100ade9c7d40fb424638fde48',
+    }
diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py
index 0917483c..079cf4fe 100644
--- a/codes/data/torch_dataset.py
+++ b/codes/data/torch_dataset.py
@@ -4,6 +4,7 @@
 import torchvision.transforms as T
 from torchvision import datasets
 
 # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
+from data.cifar import CIFAR100, CIFAR10
 from utils.util import opt_get
 
@@ -12,8 +13,8 @@ class TorchDataset(Dataset):
     DATASET_MAP = {
         "mnist": datasets.MNIST,
         "fmnist": datasets.FashionMNIST,
-        "cifar10": datasets.CIFAR10,
-        "cifar100": datasets.CIFAR100,
+        "cifar10": CIFAR10,
+        "cifar100": CIFAR100,
         "imagenet": datasets.ImageNet,
         "imagefolder": datasets.ImageFolder
     }
@@ -39,8 +40,15 @@ class TorchDataset(Dataset):
         self.offset = opt_get(opt, ['offset'], 0)
 
     def __getitem__(self, item):
-        underlying_item, lbl = self.dataset[item+self.offset]
-        return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl,
+        item = self.dataset[item+self.offset]
+        if len(item) == 2:
+            underlying_item, lbl = item
+            coarselbl = None
+        elif len(item) == 3:
+            underlying_item, lbl, coarselbl = item
+        else:
+            raise NotImplementedError
+        return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'coarse_labels': coarselbl,
                 'LQ_path': str(item), 'GT_path': str(item)}
 
     def __len__(self):
diff --git a/codes/models/classifiers/__init__.py b/codes/models/classifiers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/cifar_resnet.py b/codes/models/classifiers/cifar_resnet.py
similarity index 100%
rename from codes/models/cifar_resnet.py
rename to codes/models/classifiers/cifar_resnet.py
diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py
new file mode 100644
index 00000000..03c6ac0b
--- /dev/null
+++ b/codes/models/classifiers/cifar_resnet_branched.py
@@ -0,0 +1,175 @@
+"""resnet in pytorch
+
+
+
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
+
+    Deep Residual Learning for Image Recognition
+    https://arxiv.org/abs/1512.03385v1
+"""
+
+import torch
+import torch.nn as nn
+
+from trainer.networks import register_model
+
+
+class BasicBlock(nn.Module):
+    """Basic Block for resnet 18 and resnet 34
+
+    """
+
+    # BasicBlock and BottleNeck blocks
+    # have different output sizes;
+    # we use the class attribute 'expansion'
+    # to distinguish them.
+    expansion = 1
+
+    def __init__(self, in_channels, out_channels, stride=1):
+        super().__init__()
+
+        # residual function
+        self.residual_function = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+        )
+
+        # shortcut
+        self.shortcut = nn.Sequential()
+
+        # when the shortcut output dimension is not the same as the residual function's,
+        # use a 1x1 convolution to match the dimensions
+        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+            )
+
+    def forward(self, x):
+        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+class BottleNeck(nn.Module):
+    """Residual block for resnet over 50 layers
+
+    """
+    expansion = 4
+    def __init__(self, in_channels, out_channels, stride=1):
+        super().__init__()
+        self.residual_function = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
+        )
+
+        self.shortcut = nn.Sequential()
+
+        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
+                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
+            )
+
+    def forward(self, x):
+        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+
+class ResNetTail(nn.Module):
+    def __init__(self, block, num_block, num_classes=100):
+        super().__init__()
+
+        self.in_channels = 128
+        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
+        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
+        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, out_channels, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channels, out_channels, stride))
+            self.in_channels = out_channels * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        output = self.conv4_x(x)
+        output = self.conv5_x(output)
+        output = self.avg_pool(output)
+        output = output.view(output.size(0), -1)
+        output = self.fc(output)
+
+        return output
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, num_block, num_classes=100, num_tails=20):
+        super().__init__()
+        self.in_channels = 64
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True))
+
+        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
+        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
+        self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])
+
+    def _make_layer(self, block, out_channels, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channels, out_channels, stride))
+            self.in_channels = out_channels * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x, coarse_label):
+        output = self.conv1(x)
+        output = self.conv2_x(output)
+        output = self.conv3_x(output)
+        bs = output.shape[0]
+        tailouts = []
+        for t in self.tails:
+            tailouts.append(t(output))
+        tailouts = torch.stack(tailouts, dim=0)
+        return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs, bs, 1)).sum(dim=1)
+
+@register_model
+def register_cifar_resnet18(opt_net, opt):
+    """ return a ResNet 18 object
+    """
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+def resnet34():
+    """ return a ResNet 34 object
+    """
+    return ResNet(BasicBlock, [3, 4, 6, 3])
+
+def resnet50():
+    """ return a ResNet 50 object
+    """
+    return ResNet(BottleNeck, [3, 4, 6, 3])
+
+def resnet101():
+    """ return a ResNet 101 object
+    """
+    return ResNet(BottleNeck, [3, 4, 23, 3])
+
+def resnet152():
+    """ return a ResNet 152 object
+    """
+    return ResNet(BottleNeck, [3, 8, 36, 3])
+
+
+if __name__ == '__main__':
+    model = ResNet(BasicBlock, [2, 2, 2, 2])
+    print(model(torch.randn(2, 3, 32, 32), torch.LongTensor([4, 19])).shape)
+
diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/classifiers/resnet_with_checkpointing.py
similarity index 100%
rename from codes/models/resnet_with_checkpointing.py
rename to codes/models/classifiers/resnet_with_checkpointing.py
diff --git a/codes/models/weighted_conv_resnet.py b/codes/models/classifiers/weighted_conv_resnet.py
similarity index 100%
rename from codes/models/weighted_conv_resnet.py
rename to codes/models/classifiers/weighted_conv_resnet.py
diff --git a/codes/scripts/byol/byol_resnet_playground.py b/codes/scripts/byol/byol_resnet_playground.py
index 2fd49975..b5014567 100644
--- a/codes/scripts/byol/byol_resnet_playground.py
+++ b/codes/scripts/byol/byol_resnet_playground.py
@@ -4,22 +4,16 @@
 import shutil
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
 from PIL import Image
 from torch.utils.data import DataLoader
-from torchvision.transforms import ToTensor, Resize
+from torchvision.transforms import ToTensor
 from tqdm import tqdm
-import numpy as np
 
-import utils
 from data.image_folder_dataset import ImageFolderDataset
-from models.resnet_with_checkpointing import resnet50
-from models.spinenet_arch import SpineNet
-
+from models.classifiers.resnet_with_checkpointing import resnet50
 
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
-from utils import util
 from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict
diff --git a/codes/scripts/byol/byol_segformer_playground.py b/codes/scripts/byol/byol_segformer_playground.py
index 736d32a6..2fc9db51 100644
--- a/codes/scripts/byol/byol_segformer_playground.py
+++ b/codes/scripts/byol/byol_segformer_playground.py
@@ -1,5 +1,4 @@
 import os
-import shutil
 
 import torch
 import torch.nn as nn
@@ -7,20 +6,14 @@
 import torch.nn.functional as F
 import torchvision
 from PIL import Image
 from torch.utils.data import DataLoader
-from torchvision.transforms import ToTensor, Resize, Normalize
+from torchvision.transforms import ToTensor, Normalize
 from tqdm import tqdm
-import numpy as np
 
-import utils
 from data.image_folder_dataset import ImageFolderDataset
-from models.resnet_with_checkpointing import resnet50
 from models.segformer.segformer import Segformer
-from models.spinenet_arch import SpineNet
-
 
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
-from utils import util
 from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict
diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py
index a09dd83c..5bbf69ea 100644
--- a/codes/scripts/byol/byol_uresnet_playground.py
+++ b/codes/scripts/byol/byol_uresnet_playground.py
@@ -1,5 +1,4 @@
 import os
-import shutil
 from random import shuffle
 
 import matplotlib.cm as cm
@@ -7,26 +6,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from PIL import Image
 from torch.utils.data import DataLoader
 from torchvision.models.resnet import Bottleneck
-from torchvision.transforms import ToTensor, Resize
 from tqdm import tqdm
-import numpy as np
 
-import utils
 from data.image_folder_dataset import ImageFolderDataset
-from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
-from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
 from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
-from models.resnet_with_checkpointing import resnet50
-from models.spinenet_arch import SpineNet
-
 
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
-from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
-from utils import util
 from utils.kmeans import kmeans, kmeans_predict
 from utils.options import dict_to_nonedict
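
Usage sketch: with this patch, CIFAR100 items carry a third element, the coarse
("superclass") label, and TorchDataset surfaces it under the 'coarse_labels' key.
A minimal check of the dataset path, assuming a hypothetical ./datasets root and
that the archive can be downloaded:

    from torchvision.transforms import ToTensor
    from data.cifar import CIFAR10, CIFAR100

    # CIFAR-100 groups its 100 fine classes into 20 coarse superclasses; the
    # pickled batches expose them as 'coarse_labels', which this copy of the
    # dataset now retains.
    ds = CIFAR100(root='./datasets', train=True, download=True, transform=ToTensor())
    img, fine, coarse = ds[0]          # 3-tuple, since coarse_targets is populated
    assert 0 <= fine < 100 and 0 <= coarse < 20

    # CIFAR-10 batches carry no coarse labels, so the 2-tuple form is unchanged.
    img, target = CIFAR10(root='./datasets', train=True, download=True)[0]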
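
The tail-selection line in ResNet.forward deserves a note: `tailouts` is stacked
to [num_tails, bs, num_classes], indexing it with `coarse_label` yields a
[bs, bs, num_classes] tensor whose slice i is tail coarse_label[i]'s output for
the whole batch, and the identity-matrix mask plus sum keeps only row i of
slice i. A standalone sketch (illustration only, not part of the patch) showing
that this reduces to a plain per-sample gather:

    import torch

    num_tails, bs, num_classes = 20, 2, 100
    tailouts = torch.randn(num_tails, bs, num_classes)
    coarse_label = torch.LongTensor([4, 19])

    # The patch's formulation: mask with an identity matrix, then reduce.
    eye = torch.eye(n=bs).view(bs, bs, 1)
    masked = (tailouts[coarse_label] * eye).sum(dim=1)

    # Equivalent direct gather: sample i takes row i of tail coarse_label[i].
    gathered = tailouts[coarse_label, torch.arange(bs)]

    assert torch.allclose(masked, gathered)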
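
Finally, a hedged end-to-end sketch tying the two halves together; the batch
keys mirror what TorchDataset emits above, while the trainer wiring itself is
omitted:

    import torch
    from models.classifiers.cifar_resnet_branched import ResNet, BasicBlock

    # 20 tails, one per CIFAR-100 coarse superclass.
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100, num_tails=20)
    imgs = torch.randn(8, 3, 32, 32)       # stand-in for batch['hq']
    coarse = torch.randint(0, 20, (8,))    # stand-in for batch['coarse_labels']
    logits = model(imgs, coarse)           # -> torch.Size([8, 100])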