From 34f8c8641fc521d26ae04732aae419dc6df20e6e Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 11 Jan 2021 20:09:16 -0700 Subject: [PATCH] Support training imagenet classifier --- codes/data/torch_dataset.py | 45 +- .../resnet_unet_2.py | 152 ++++++ codes/models/resnet_with_checkpointing.py | 2 +- codes/models/vqvae/kmeans_mask_producer.py | 4 +- codes/models/vqvae/scaled_weight_conv.py | 57 ++- codes/models/weighted_conv_resnet.py | 441 ++++++++++++++++++ codes/requirements.txt | 3 +- codes/scripts/byol/byol_uresnet_playground.py | 28 +- codes/scripts/folderize_imagenet_val.py | 26 ++ codes/train.py | 2 +- .../trainer/eval/categorization_loss_eval.py | 97 ++++ codes/trainer/steps.py | 3 + 12 files changed, 824 insertions(+), 36 deletions(-) create mode 100644 codes/models/pixel_level_contrastive_learning/resnet_unet_2.py create mode 100644 codes/models/weighted_conv_resnet.py create mode 100644 codes/scripts/folderize_imagenet_val.py create mode 100644 codes/trainer/eval/categorization_loss_eval.py diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py index 01875bfb..7015eef9 100644 --- a/codes/data/torch_dataset.py +++ b/codes/data/torch_dataset.py @@ -10,22 +10,47 @@ class TorchDataset(Dataset): "mnist": datasets.MNIST, "fmnist": datasets.FashionMNIST, "cifar10": datasets.CIFAR10, + "imagenet": datasets.ImageNet, + "imagefolder": datasets.ImageFolder } - transforms = [] - if opt['flip']: - transforms.append(T.RandomHorizontalFlip()) - if opt['crop_sz']: - transforms.append(T.RandomCrop(opt['crop_sz'], padding=opt['padding'], padding_mode="reflect")) - transforms.append(T.ToTensor()) + normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + if opt['train']: + transforms = [ + T.RandomResizedCrop(opt['image_size']), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + ] + else: + transforms = [ + T.Resize(opt['val_resize']), + T.CenterCrop(opt['image_size']), + T.ToTensor(), + normalize, + ] transforms = T.Compose(transforms) - is_for_training = opt['test'] if 'test' in opt.keys() else True - self.dataset = DATASET_MAP[opt['dataset']](opt['datapath'], train=is_for_training, download=True, transform=transforms) + self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs']) self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset) def __getitem__(self, item): - underlying_item = self.dataset[item][0] - return {'lq': underlying_item, 'hq': underlying_item, + underlying_item, lbl = self.dataset[item] + return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'LQ_path': str(item), 'GT_path': str(item)} def __len__(self): return self.len + +if __name__ == '__main__': + opt = { + 'flip': True, + 'crop_sz': None, + 'dataset': 'imagefolder', + 'resize': 256, + 'center_crop': 224, + 'normalize': True, + 'kwargs': { + 'root': 'F:\\4k6k\\datasets\\images\\imagenet_2017\\val', + } + } + set = TorchDataset(opt) + j = set[0] diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py b/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py new file mode 100644 index 00000000..46bc747f --- /dev/null +++ b/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py @@ -0,0 +1,152 @@ +# Resnet implementation that adds a u-net style up-conversion component to output values at a +# specified pixel density. +# +# The downsampling part of the network is compatible with the built-in torch resnet for use in +# transfer learning. +# +# Only resnet50 currently supported. 
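+#
+# The downsampling trunk follows torchvision's resnet50 (layer1..layer4, producing
+# features at 1/4, 1/8, 1/16 and 1/32 of the input resolution). Two ReverseBottleneck
+# stages then upsample the 1/32 features back to 1/8 resolution, merging the layer3
+# output via a passthrough connection, and a small convolutional tail projects the
+# result (concatenated with the layer2 output) down to `out_dim` channels.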
+ +import torch +import torch.nn as nn +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 +from torchvision.models.utils import load_state_dict_from_url +import torchvision + + +from trainer.networks import register_model +from utils.util import checkpoint, opt_get + + +class ReverseBottleneck(nn.Module): + + def __init__(self, inplanes, planes, groups=1, passthrough=False, + base_width=64, dilation=1, norm_layer=None): + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + self.passthrough = passthrough + if passthrough: + self.integrate = conv1x1(inplanes*2, inplanes) + self.bn_integrate = norm_layer(inplanes) + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, groups, dilation) + self.bn2 = norm_layer(width) + self.residual_upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv1x1(width, width), + norm_layer(width), + ) + self.conv3 = conv1x1(width, planes) + self.bn3 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv1x1(inplanes, planes), + norm_layer(planes), + ) + + def forward(self, x, passthrough=None): + if self.passthrough: + x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1))) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.residual_upsample(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.upsample(x) + + out = out + identity + out = self.relu(out) + + return out + + +class UResNet50(torchvision.models.resnet.ResNet): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, out_dim=128): + super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, + replace_stride_with_dilation, norm_layer) + if norm_layer is None: + norm_layer = nn.BatchNorm2d + ''' + # For reference: + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + ''' + uplayers = [] + inplanes = 2048 + first = True + for i in range(2): + uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first)) + inplanes = inplanes // 2 + first = False + self.uplayers = nn.ModuleList(uplayers) + self.tail = nn.Sequential(conv1x1(1024, 512), + norm_layer(512), + nn.ReLU(), + conv3x3(512, 512), + norm_layer(512), + nn.ReLU(), + conv1x1(512, out_dim)) + + del self.fc # Not used in this implementation and just consumes a ton of GPU memory. + + + def _forward_impl(self, x): + # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl, + # except using checkpoints on the body conv layers. 
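+        # Checkpointing re-runs each layer's forward pass during backprop instead of
+        # storing its intermediate activations, trading extra compute for lower memory use.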
+ x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = checkpoint(self.layer1, x) + x2 = checkpoint(self.layer2, x1) + x3 = checkpoint(self.layer3, x2) + x4 = checkpoint(self.layer4, x3) + unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused. + + x = checkpoint(self.uplayers[0], x4) + x = checkpoint(self.uplayers[1], x, x3) + #x = checkpoint(self.uplayers[2], x, x2) + #x = checkpoint(self.uplayers[3], x, x1) + + return checkpoint(self.tail, torch.cat([x, x2], dim=1)) + + def forward(self, x): + return self._forward_impl(x) + + +@register_model +def register_u_resnet50(opt_net, opt): + model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) + if opt_get(opt_net, ['use_pretrained_base'], False): + state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) + model.load_state_dict(state_dict, strict=False) + return model + + +if __name__ == '__main__': + model = UResNet50(Bottleneck, [3,4,6,3]) + samp = torch.rand(1,3,224,224) + model(samp) + # For pixpro: attach to "tail.3" diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/resnet_with_checkpointing.py index 94134d77..eafea4ea 100644 --- a/codes/models/resnet_with_checkpointing.py +++ b/codes/models/resnet_with_checkpointing.py @@ -192,7 +192,7 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs): @register_model -def register_resnet52(opt_net, opt): +def register_resnet50(opt_net, opt): model = resnet50(pretrained=opt_net['pretrained']) if opt_net['custom_head_logits']: model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits']) diff --git a/codes/models/vqvae/kmeans_mask_producer.py b/codes/models/vqvae/kmeans_mask_producer.py index 33d687d7..c1ca5fc1 100644 --- a/codes/models/vqvae/kmeans_mask_producer.py +++ b/codes/models/vqvae/kmeans_mask_producer.py @@ -10,11 +10,11 @@ from utils.util import opt_get class UResnetMaskProducer(nn.Module): - def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=[.125,.25,.5,1]): + def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=[.125,.25,.5,1], tail_dim=512): super().__init__() _, centroids = torch.load(kmeans_centroid_path) self.centroids = nn.Parameter(centroids) - self.ures = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda') + self.ures = UResNet50(Bottleneck, [3,4,6,3], out_dim=tail_dim).to('cuda') self.mask_scales = mask_scales sd = torch.load(pretrained_uresnet_path) diff --git a/codes/models/vqvae/scaled_weight_conv.py b/codes/models/vqvae/scaled_weight_conv.py index a6c212da..1e5d6d54 100644 --- a/codes/models/vqvae/scaled_weight_conv.py +++ b/codes/models/vqvae/scaled_weight_conv.py @@ -48,9 +48,8 @@ class ScaledWeightConv(_ConvNd): w.FOR_SCALE_SHIFT = True s.FOR_SCALE_SHIFT = True # This should probably be configurable at some point. - for p in self.parameters(): - if not hasattr(p, "FOR_SCALE_SHIFT"): - p.DO_NOT_TRAIN = True + self.weight.DO_NOT_TRAIN = True + self.weight.requires_grad = False def _weighted_conv_forward(self, input, weight): if self.padding_mode != 'zeros': @@ -60,7 +59,12 @@ class ScaledWeightConv(_ConvNd): return F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - def forward(self, input: Tensor, masks: dict) -> Tensor: + def forward(self, input: Tensor, masks: dict = None) -> Tensor: + if masks is None: + # An alternate "mode" of operation is the masks are injected as parameters. 
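+            # In that case the caller is expected to have assigned `self.masks` on this
+            # module before invoking it, which keeps forward() usable from code that only
+            # passes the input tensor.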
+ assert hasattr(self, 'masks') + masks = self.masks + # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any # good at all, this can be made more efficient by performing a single conv pass with multiple masks. weighted_convs = [self._weighted_conv_forward(input, self.weight * scale + shift) for scale, shift in zip(self.weight_scales, self.shifts)] @@ -72,6 +76,20 @@ class ScaledWeightConv(_ConvNd): return index_2d(weighted_convs, masks[needed_mask]) +def create_wrapped_conv_from_template(conv: nn.Conv2d, breadth: int): + wrapped = ScaledWeightConv(conv.in_channels, + conv.out_channels, + conv.kernel_size[0], + conv.stride[0], + conv.padding[0], + conv.dilation[0], + conv.groups, + conv.bias, + conv.padding_mode, + breadth) + return wrapped + + # Drop-in implementation of ConvTranspose2d that can apply masked scales&shifts to the convolution weights. class ScaledWeightConvTranspose(_ConvTransposeNd): def __init__( @@ -102,9 +120,8 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): w.FOR_SCALE_SHIFT = True s.FOR_SCALE_SHIFT = True # This should probably be configurable at some point. - for nm, p in self.named_parameters(): - if nm == 'weight': - p.DO_NOT_TRAIN = True + self.weight.DO_NOT_TRAIN = True + self.weight.requires_grad = False def _conv_transpose_forward(self, input, weight, output_size) -> Tensor: if self.padding_mode != 'zeros': @@ -117,7 +134,12 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) - def forward(self, input: Tensor, masks: dict, output_size: Optional[List[int]] = None) -> Tensor: + def forward(self, input: Tensor, masks: dict = None, output_size: Optional[List[int]] = None) -> Tensor: + if masks is None: + # An alternate "mode" of operation is the masks are injected as parameters. + assert hasattr(self, 'masks') + masks = self.masks + # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any # good at all, this can be made more efficient by performing a single conv pass with multiple masks. 
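+        # Each (scale, shift) pair yields a full transposed convolution; index_2d then
+        # picks, per output pixel, which of those results to keep according to the mask
+        # at the matching resolution.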
weighted_convs = [self._conv_transpose_forward(input, self.weight * scale + shift, output_size) @@ -128,3 +150,22 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): assert needed_mask in masks.keys() return index_2d(weighted_convs, masks[needed_mask]) + + +def create_wrapped_conv_transpose_from_template(conv: nn.Conv2d, breadth: int): + wrapped = ScaledWeightConvTranspose(conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.output_padding, + conv.groups, + conv.bias, + conv.dilation, + conv.padding_mode, + breadth) + wrapped.weight = conv.weight + wrapped.weight.DO_NOT_TRAIN = True + wrapped.weight.requires_grad = False + wrapped.bias = conv.bias + return wrapped diff --git a/codes/models/weighted_conv_resnet.py b/codes/models/weighted_conv_resnet.py new file mode 100644 index 00000000..6261e1cc --- /dev/null +++ b/codes/models/weighted_conv_resnet.py @@ -0,0 +1,441 @@ +import torch +import torchvision +from torch import Tensor +import torch.nn as nn +from torchvision.models.utils import load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + +from models.vqvae.scaled_weight_conv import ScaledWeightConv +from trainer.networks import register_model +from utils.util import checkpoint + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, breadth: int = 8) -> ScaledWeightConv: + """3x3 convolution with padding""" + return ScaledWeightConv(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation, breadth=breadth) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, breadth: int = 8) -> ScaledWeightConv: + """1x1 convolution""" + return ScaledWeightConv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, breadth=breadth) + + +# Provides similar API to nn.Sequential, but handles feed-forward networks that need to feed masks into their convolutions. 
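+# Masks are read from a `masks` attribute that the caller sets on the container before
+# invoking it; they are forwarded to any mask-aware child (ScaledWeightConv, BasicBlock,
+# Bottleneck), while ordinary modules such as the norm layers are called as usual.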
+class MaskedSequential(nn.Module): + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def __iter__(self) -> Iterator[nn.Module]: + return iter(self._modules.values()) + + def forward(self, x): + mask = self.masks + for m in self: + if isinstance(m, ScaledWeightConv) or isinstance(m, BasicBlock) or isinstance(m, Bottleneck): + x = m(x, mask) + else: + x = m(x) + return x + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + breadth: int = 8 + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride, breadth=breadth) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, breadth=breadth) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + identity = x + + out = self.conv1(x, mask) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out, mask) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x, mask) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
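+    # This copy differs from the torchvision Bottleneck only in that every convolution is a
+    # ScaledWeightConv parameterized by `breadth`, and forward() additionally takes the masks
+    # used to select per-pixel weight scales and shifts.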
+ + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + breadth: int = 8 + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width, breadth=breadth) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation, breadth=breadth) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion, breadth=breadth) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + identity = x + + out = self.conv1(x, mask) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out, mask) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out, mask) + out = self.bn3(out) + + if self.downsample is not None: + self.downsample.masks = mask + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + breadth: int = 8 + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = ScaledWeightConv(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False, breadth=breadth) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], breadth) + self.layer2 = self._make_layer(block, 128, layers[1], breadth, stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], breadth, stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, ScaledWeightConv): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like 
an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, breadth: int, + stride: int = 1, dilate: bool = False) -> MaskedSequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = MaskedSequential( + conv1x1(self.inplanes, planes * block.expansion, stride, breadth=breadth), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer, breadth=breadth)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer, breadth=breadth)) + + return MaskedSequential(*layers) + + def _forward_impl(self, x: Tensor, mask: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x, mask) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + for m in [self.layer1, self.layer2, self.layer3, self.layer4]: + m.masks = mask + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + x = checkpoint(self.layer3, x) + x = checkpoint(self.layer4, x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + return self._forward_impl(x, mask) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict, strict=False) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +@register_model +def register_resnet50_weighted_conv(opt_net, opt): + model = resnet50(pretrained=opt_net['pretrained'], **opt_net['kwargs']) + return model + + +if __name__ == '__main__': + orig = torchvision.models.resnet.resnet50(pretrained=True) + mod = resnet50(pretrained=True, breadth=4) + idim = 224 + masks = {} + for j in range(6): + cdim = idim // (2 ** j) + masks[cdim] = torch.zeros((1,1,cdim,cdim), dtype=torch.long) + i = torch.rand(1,3,idim,idim) + r1 = mod(i, masks) + r2 = orig(i) diff --git a/codes/requirements.txt b/codes/requirements.txt index b8b63f26..0e218500 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -15,4 +15,5 @@ pytorch_fid==0.1.1 kornia linear_attention_transformer vector_quantize_pytorch -orjson \ No newline at end of file +orjson +einops \ No newline at end of file diff --git a/codes/scripts/byol/byol_uresnet_playground.py b/codes/scripts/byol/byol_uresnet_playground.py index 35b4133d..1eb01068 100644 --- a/codes/scripts/byol/byol_uresnet_playground.py +++ b/codes/scripts/byol/byol_uresnet_playground.py @@ -59,7 +59,8 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers, target_size=256): dataset_opt = dict_to_nonedict({ 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + 'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], 'weights': [1], @@ -94,22 +95,23 @@ def produce_latent_dict(model): id += batch_size if id > 1000: print("Saving checkpoint..") - torch.save((latents, paths), '../results.pth') + torch.save((latents, paths), '../imagenet_latent_dict.pth') id = 0 def build_kmeans(): - latents, _ = torch.load('../results.pth') + latents, _ = torch.load('../imagenet_latent_dict.pth') latents = torch.cat(latents, dim=0).to('cuda') - cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0')) - torch.save((cluster_ids_x, cluster_centers), '../k_means.pth') + cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=4, distance="euclidean", device=torch.device('cuda:0')) + torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth') def use_kmeans(): - _, centers = torch.load('../experiments/k_means_uresnet_512.pth') + _, centers = torch.load('../k_means_imagenet.pth') + centers = centers.to('cuda') batch_size = 8 num_workers = 0 - dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512) + dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256) colormap = cm.get_cmap('viridis', 8) for i, batch in enumerate(tqdm(dataloader)): hq = batch['hq'].to('cuda') @@ -117,16 +119,16 @@ def use_kmeans(): b, c, h, w = l.shape dim = b*h*w l = l.permute(0,2,3,1).reshape(dim,c) - pred = kmeans_predict(l, centers, device=l.device) + pred = kmeans_predict(l, centers) pred = pred.reshape(b,h,w) - img = torch.tensor(colormap(pred[:, :, :].detach().numpy())) + img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy())) 
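+            # The category map comes out at the model's 1/8 latent resolution; the 8x
+            # nearest-neighbour upsample below restores it to the input size for viewing.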
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png") torchvision.utils.save_image(hq, f"{i}_hq.png") if __name__ == '__main__': - pretrained_path = '../experiments/uresnet_pixpro_512.pth' - model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda') + pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth' + model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda') sd = torch.load(pretrained_path) resnet_sd = {} for k, v in sd.items(): @@ -139,5 +141,5 @@ if __name__ == '__main__': #find_similar_latents(model, 0, 8, structural_euc_dist) #create_latent_database(model, batch_size=32) #produce_latent_dict(model) - #build_kmeans() - use_kmeans() + build_kmeans() + #use_kmeans() diff --git a/codes/scripts/folderize_imagenet_val.py b/codes/scripts/folderize_imagenet_val.py new file mode 100644 index 00000000..69db93bd --- /dev/null +++ b/codes/scripts/folderize_imagenet_val.py @@ -0,0 +1,26 @@ +from glob import glob + +import torch +import os +import shutil + +if __name__ == '__main__': + index_map_file = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\imagenet_index_to_train_folder_name_map.pth' + ground_truth = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\validation_ground_truth.txt' + val_path = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\val' + + index_map = torch.load(index_map_file) + + for folder in index_map.values(): + os.makedirs(os.path.join(val_path, folder), exist_ok=True) + + gtfile = open(ground_truth, 'r') + gtids = [] + for line in gtfile: + gtids.append(int(line.strip())) + gtfile.close() + + for i, img_file in enumerate(glob(os.path.join(val_path, "*.JPEG"))): + shutil.move(img_file, os.path.join(val_path, index_map[gtids[i]], + os.path.basename(img_file))) + print("Done!") diff --git a/codes/train.py b/codes/train.py index 9117131f..59d821c9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_vqvae_stage1/train_imgset_vqvae_stage1_5.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_pixpro_resnet.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py new file mode 100644 index 00000000..57fb33ea --- /dev/null +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -0,0 +1,97 @@ +import torch +import torchvision +from torch.nn.functional import interpolate +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm + +import trainer.eval.evaluator as evaluator +from models.vqvae.kmeans_mask_producer import UResnetMaskProducer +from utils.util import opt_get + + +class CategorizationLossEvaluator(evaluator.Evaluator): + def __init__(self, model, opt_eval, env): + super().__init__(model, opt_eval, env) + self.batch_sz = opt_eval['batch_size'] + assert self.batch_sz is not None + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.dataset = torchvision.datasets.ImageFolder( + 'F:\\4k6k\\datasets\\images\\imagenet_2017\\val', + transforms.Compose([ + 
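+                # standard ImageNet validation preprocessing: resize the short side to 256,
+                # then take a 224x224 center crop and normalize with the ImageNet statistics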
transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4) + self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 + self.masking = opt_get(opt_eval, ['masking'], True) + if self.masking: + self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth', + kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth', + mask_scales=[.03125, .0625, .125, .25, .5, 1.0], + tail_dim=256).to('cuda') + + def accuracy(self, output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target[None]) + + res = [] + for k in topk: + correct_k = correct[:k].flatten().sum(dtype=torch.float32) + res.append(correct_k * (100.0 / batch_size)) + return res + + def perform_eval(self): + counter = 0.0 + ce_loss = 0.0 + top_5_acc = 0.0 + top_1_acc = 0.0 + + self.model.eval() + with torch.no_grad(): + for hq, labels in tqdm(self.dataloader): + hq = hq.to(self.env['device']) + labels = labels.to(self.env['device']) + if self.masking: + masks = self.mask_producer(hq) + logits = self.model(hq, masks) + else: + logits = self.model(hq) + if not isinstance(logits, list) and not isinstance(logits, tuple): + logits = [logits] + logits = logits[self.gen_output_index] + ce_loss += torch.nn.functional.cross_entropy(logits, labels).detach() + t1, t5 = self.accuracy(logits, labels, (1, 5)) + top_1_acc += t1.detach() + top_5_acc += t5.detach() + counter += len(hq) / self.batch_sz + self.model.train() + + return {"val_cross_entropy": ce_loss / counter, + "top_5_accuracy": top_5_acc / counter, + "top_1_accuracy": top_1_acc / counter } + + +if __name__ == '__main__': + from torchvision.models import resnet50 + model = resnet50(pretrained=True).to('cuda') + opt = { + 'batch_size': 128, + 'gen_index': 0, + 'masking': False + } + env = { + 'device': 'cuda', + + } + eval = CategorizationLossEvaluator(model, opt, env) + print(eval.perform_eval()) diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 07c24e2b..df4cade0 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -107,6 +107,9 @@ class ConfigurableStep(Module): optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay']) opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient']) + elif self.step_opt['optimizer'] == 'sgd': + from torch.optim import SGD + opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay']) opt._config = opt_config # This is a bit seedy, but we will need these configs later. opt._config['network'] = net_name self.optimizers.append(opt)
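
Note on the new 'sgd' branch in ConfigurableStep: it only reads lr, momentum and weight_decay from the step's optimizer config. A minimal sketch of the equivalent construction follows; the concrete values and the `model` variable are placeholders, not part of this patch, and the YAML layout of the training options file is not shown here.

    from torch.optim import SGD
    import torch.nn as nn

    model = nn.Linear(16, 4)  # placeholder standing in for the step's trainable parameters
    opt_config = {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 1e-4}  # hypothetical values
    optimizer = SGD(model.parameters(), lr=opt_config['lr'],
                    momentum=opt_config['momentum'],
                    weight_decay=opt_config['weight_decay'])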