CIFAR stuff
- Extract coarse labels for the CIFAR dataset - Add simple resnet that branches lower layers based on coarse labels - Some other cleanup
This commit is contained in:
parent
80d4404367
commit
fb405d9ef1
176
codes/data/cifar.py
Normal file
176
codes/data/cifar.py
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
# A copy of the cifar dataset from torch which also returns coarse labels.
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import os.path
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
from typing import Any, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
from torchvision.datasets import VisionDataset
|
||||||
|
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||||
|
|
||||||
|
|
||||||
|
class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset that can also yield coarse labels.

    The batch files are unpickled eagerly in ``__init__``. When a batch carries
    a ``labels`` entry (the CIFAR-10 layout) only fine targets are collected;
    otherwise ``fine_labels`` and ``coarse_labels`` are both read, and
    ``__getitem__`` then yields a third element holding the coarse label.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it. It is not applied to the coarse label.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # Each entry is [file name, expected md5].
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        batch_files = self.train_list if self.train else self.test_list

        raw_batches: Any = []
        self.targets = []
        self.coarse_targets = []

        # now load the pickled numpy arrays
        for batch_name, _md5 in batch_files:
            batch_path = os.path.join(self.root, self.base_folder, batch_name)
            with open(batch_path, 'rb') as handle:
                batch = pickle.load(handle, encoding='latin1')
            raw_batches.append(batch['data'])
            if 'labels' in batch:
                self.targets.extend(batch['labels'])
            else:
                # No 'labels' key: read the fine/coarse label pair instead.
                self.targets.extend(batch['fine_labels'])
                self.coarse_targets.extend(batch['coarse_labels'])

        # Stack the rows into N x 3 x 32 x 32, then move channels last (HWC).
        self.data = np.vstack(raw_batches).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))

        self._load_meta()

    def _load_meta(self) -> None:
        """Read the class names from the metadata file and build ``class_to_idx``."""
        meta_path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(meta_path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(meta_path, 'rb') as handle:
            meta_blob = pickle.load(handle, encoding='latin1')
        self.classes = meta_blob[self.meta['key']]
        self.class_to_idx = {name: position for position, name in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class, or
            (image, target, coarse_target) when coarse labels were loaded.
        """
        target = self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(self.data[index])

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.coarse_targets:
            return img, target, self.coarse_targets[index]
        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        """Return True only if every train and test batch file matches its md5."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, name), md5)
            for name, md5 in (self.train_list + self.test_list)
        )

    def download(self) -> None:
        """Download and extract the archive unless the batch files already verify."""
        if not self._check_integrity():
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
            return
        print('Files already downloaded and verified')

    def extra_repr(self) -> str:
        split_name = "Train" if self.train is True else "Test"
        return "Split: {}".format(split_name)
|
||||||
|
|
||||||
|
|
||||||
|
class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Subclass of `CIFAR10` that only overrides the archive location, the batch
    file names/checksums and the metadata key; all loading logic is inherited.
    """
    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    # Each entry is [file name, expected md5].
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]
    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )
|
|
@ -4,6 +4,7 @@ import torchvision.transforms as T
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
|
|
||||||
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
||||||
|
from data.cifar import CIFAR100, CIFAR10
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,8 +13,8 @@ class TorchDataset(Dataset):
|
||||||
DATASET_MAP = {
|
DATASET_MAP = {
|
||||||
"mnist": datasets.MNIST,
|
"mnist": datasets.MNIST,
|
||||||
"fmnist": datasets.FashionMNIST,
|
"fmnist": datasets.FashionMNIST,
|
||||||
"cifar10": datasets.CIFAR10,
|
"cifar10": CIFAR10,
|
||||||
"cifar100": datasets.CIFAR100,
|
"cifar100": CIFAR100,
|
||||||
"imagenet": datasets.ImageNet,
|
"imagenet": datasets.ImageNet,
|
||||||
"imagefolder": datasets.ImageFolder
|
"imagefolder": datasets.ImageFolder
|
||||||
}
|
}
|
||||||
|
@ -39,8 +40,15 @@ class TorchDataset(Dataset):
|
||||||
self.offset = opt_get(opt, ['offset'], 0)
|
self.offset = opt_get(opt, ['offset'], 0)
|
||||||
|
|
||||||
def __getitem__(self, item):
    """Fetch one sample from the wrapped dataset and repackage it as a dict.

    Args:
        item: Index requested by the caller; ``self.offset`` is added before
            the wrapped dataset is indexed.

    Returns:
        dict with keys ``lq``/``hq`` (the same underlying sample), ``labels``,
        ``coarse_labels`` (``None`` when the wrapped dataset yields only two
        elements) and ``LQ_path``/``GT_path`` (the requested index as a string).

    Raises:
        NotImplementedError: If the wrapped dataset yields a sample that does
            not have exactly 2 or 3 elements.
    """
    # Keep the fetched sample in its own name: rebinding `item` here would make
    # the *_path entries below stringify the whole sample instead of the index.
    sample = self.dataset[item + self.offset]
    if len(sample) == 2:
        underlying_item, lbl = sample
        coarselbl = None
    elif len(sample) == 3:
        underlying_item, lbl, coarselbl = sample
    else:
        raise NotImplementedError
    # NOTE(review): a None 'coarse_labels' value is presumably rejected by the
    # default DataLoader collate function - confirm how callers batch this dict.
    return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'coarse_labels': coarselbl,
            'LQ_path': str(item), 'GT_path': str(item)}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
0
codes/models/classifiers/__init__.py
Normal file
0
codes/models/classifiers/__init__.py
Normal file
175
codes/models/classifiers/cifar_resnet_branched.py
Normal file
175
codes/models/classifiers/cifar_resnet_branched.py
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
"""resnet in pytorch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||||
|
|
||||||
|
Deep Residual Learning for Image Recognition
|
||||||
|
https://arxiv.org/abs/1512.03385v1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from trainer.networks import register_model
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
    """Two-convolution residual block (the ResNet-18/34 building block).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Base width of the block; the block emits
            ``out_channels * expansion`` channels.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    # BasicBlock and BottleNeck emit a different number of channels for the
    # same `out_channels`; the `expansion` class attribute captures that ratio.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded_channels = out_channels * BasicBlock.expansion

        # Main path: conv-bn-relu followed by conv-bn.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, expanded_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(expanded_channels),
        )

        # Skip path: an empty Sequential passes the input through unchanged.
        # When the main path changes resolution or width, a 1x1 convolution
        # projects the input so the two paths can be added.
        if stride != 1 or in_channels != expanded_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual_function(x) + shortcut(x))``."""
        main_path = self.residual_function(x)
        skip_path = self.shortcut(x)
        activation = nn.ReLU(inplace=True)
        return activation(main_path + skip_path)
|
||||||
|
|
||||||
|
class BottleNeck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (used by ResNets with 50+ layers).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Width of the inner 3x3 convolution; the block emits
            ``out_channels * expansion`` channels.
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded_channels = out_channels * BottleNeck.expansion

        # Main path: reduce with 1x1, process with 3x3, expand with 1x1.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded_channels),
        )

        # Skip path: pass-through unless resolution or width changes, in which
        # case a strided 1x1 convolution projects the input to match.
        if stride != 1 or in_channels != expanded_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(expanded_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual_function(x) + shortcut(x))``."""
        main_path = self.residual_function(x)
        skip_path = self.shortcut(x)
        activation = nn.ReLU(inplace=True)
        return activation(main_path + skip_path)
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetTail(nn.Module):
    """Upper half of a CIFAR ResNet: stages conv4_x and conv5_x, pooling and classifier.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, out_channels, stride)`` and must expose an
            integer ``expansion`` attribute.
        num_block: Sequence of per-stage block counts; only indices 2 and 3
            are read here.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        # The tail consumes the output of a 128-wide stage, whose real channel
        # count is 128 * block.expansion. Hard-coding 128 only works for
        # expansion == 1 and produces a channel mismatch for bottleneck blocks.
        self.in_channels = 128 * block.expansion
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1.

        ``self.in_channels`` is advanced after every block so the next block
        (and the next stage) receives the right input width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run both stages, global-average-pool, flatten and classify."""
        output = self.conv4_x(x)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
    """CIFAR ResNet whose upper layers are replicated into per-coarse-label tails.

    A shared trunk (conv1, conv2_x, conv3_x) feeds ``num_tails`` independent
    ``ResNetTail`` modules. Every tail is evaluated on the whole batch, and each
    sample's output is taken from the tail selected by its coarse label.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, out_channels, stride)`` and must expose an
            integer ``expansion`` attribute.
        num_block: Sequence of four per-stage block counts.
        num_classes: Output size of every tail's classifier.
        num_tails: Number of tails to create; coarse labels index into them.
    """

    def __init__(self, block, num_block, num_classes=100, num_tails=20):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1.

        ``self.in_channels`` is advanced after every block so the next block
        (and the next stage) receives the right input width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, coarse_label):
        """Classify ``x``, routing each sample through the tail named by its coarse label.

        Args:
            x: Input batch fed to ``conv1``.
            coarse_label: One integer tail index per sample
                (assumed shape ``(batch,)`` - TODO confirm against callers).

        Returns:
            Tensor with one row per sample, taken from the selected tail's output.
        """
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        bs = output.shape[0]

        # Every tail still sees the full batch (as before); only the selection
        # step below differs from a masked sum.
        tailouts = torch.stack([t(output) for t in self.tails], dim=0)

        # Pick tailouts[coarse_label[i], i] for each sample i directly. This
        # avoids a CPU-resident float32 identity mask, which breaks on
        # non-CPU devices, promotes half-precision outputs, needs a
        # (bs, bs, classes) intermediate and lets NaN/inf from unselected
        # tails leak through via 0 * inf.
        coarse_label = torch.as_tensor(coarse_label, device=tailouts.device)
        sample_idx = torch.arange(bs, device=tailouts.device)
        return tailouts[coarse_label, sample_idx]
|
||||||
|
|
||||||
|
@register_model
def register_cifar_resnet18(opt_net, opt):
    """Build the branched ResNet-18 (BasicBlock, stage depths [2, 2, 2, 2]).

    Both arguments are accepted for the registry's calling convention but are
    not read; the network is built with its default class and tail counts.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_block=stage_depths)
|
||||||
|
|
||||||
|
def resnet34():
    """Build a branched ResNet from BasicBlock with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, num_block=stage_depths)
|
||||||
|
|
||||||
|
def resnet50():
    """Build a branched ResNet from BottleNeck with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BottleNeck, num_block=stage_depths)
|
||||||
|
|
||||||
|
def resnet101():
    """Build a branched ResNet from BottleNeck with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(block=BottleNeck, num_block=stage_depths)
|
||||||
|
|
||||||
|
def resnet152():
    """Build a branched ResNet from BottleNeck with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(block=BottleNeck, num_block=stage_depths)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: push two random 32x32 RGB images through the ResNet-18
    # variant, routing them to tails 4 and 19, and print the output shape.
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    images = torch.randn(2, 3, 32, 32)
    coarse = torch.LongTensor([4, 19])
    logits = model(images, coarse)
    print(logits.shape)
|
||||||
|
|
|
@ -4,22 +4,16 @@ import shutil
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.transforms import ToTensor, Resize
|
from torchvision.transforms import ToTensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import utils
|
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.resnet_with_checkpointing import resnet50
|
from models.classifiers.resnet_with_checkpointing import resnet50
|
||||||
from models.spinenet_arch import SpineNet
|
|
||||||
|
|
||||||
|
|
||||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
# and the distance is computed across the channel dimension.
|
# and the distance is computed across the channel dimension.
|
||||||
from utils import util
|
|
||||||
from utils.kmeans import kmeans, kmeans_predict
|
from utils.kmeans import kmeans, kmeans_predict
|
||||||
from utils.options import dict_to_nonedict
|
from utils.options import dict_to_nonedict
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -7,20 +6,14 @@ import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.transforms import ToTensor, Resize, Normalize
|
from torchvision.transforms import ToTensor, Normalize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import utils
|
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.resnet_with_checkpointing import resnet50
|
|
||||||
from models.segformer.segformer import Segformer
|
from models.segformer.segformer import Segformer
|
||||||
from models.spinenet_arch import SpineNet
|
|
||||||
|
|
||||||
|
|
||||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
# and the distance is computed across the channel dimension.
|
# and the distance is computed across the channel dimension.
|
||||||
from utils import util
|
|
||||||
from utils.kmeans import kmeans, kmeans_predict
|
from utils.kmeans import kmeans, kmeans_predict
|
||||||
from utils.options import dict_to_nonedict
|
from utils.options import dict_to_nonedict
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
|
|
||||||
import matplotlib.cm as cm
|
import matplotlib.cm as cm
|
||||||
|
@ -7,26 +6,15 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.models.resnet import Bottleneck
|
from torchvision.models.resnet import Bottleneck
|
||||||
from torchvision.transforms import ToTensor, Resize
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import utils
|
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
|
|
||||||
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
|
|
||||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||||
from models.resnet_with_checkpointing import resnet50
|
|
||||||
from models.spinenet_arch import SpineNet
|
|
||||||
|
|
||||||
|
|
||||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
# and the distance is computed across the channel dimension.
|
# and the distance is computed across the channel dimension.
|
||||||
from scripts.byol.byol_spinenet_playground import find_similar_latents, create_latent_database
|
|
||||||
from utils import util
|
|
||||||
from utils.kmeans import kmeans, kmeans_predict
|
from utils.kmeans import kmeans, kmeans_predict
|
||||||
from utils.options import dict_to_nonedict
|
from utils.options import dict_to_nonedict
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user