resnet-classifier/image_classifier/data.py


# todo: clean this mess up
import copy
# import h5py
import json
import logging
import numpy as np
import os
import random
import torch
import math
from .config import cfg
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from PIL import Image
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)

@cache
def get_symmap():
    return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }

@cache
def _get_symbols( content ):
    # the symbol map has no "O"; treat it as a zero
    content = content.replace("O", "0")
    return ["<s>"] + [ p for p in content ] + ["</s>"]
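
# Illustrative example (hypothetical stem): for a file named "W8RK2T.png",
#   _get_symbols("W8RK2T") -> ["<s>", "W", "8", "R", "K", "2", "T", "</s>"]
# which get_symmap() maps to the token ids [1, 20, 6, 16, 12, 4, 18, 2].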

class Dataset(_Dataset):
    def __init__(
        self,
        paths,
        width=300,
        height=80,
        symmap=get_symmap(),
        training=False,
    ):
        super().__init__()
        self._head = None
        self.paths = paths
        self.width = width
        self.height = height
        self.symmap = symmap
        self.training = training

        self.transform = transforms.Compose([
            #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* already be normalized to this size anyhow
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # ImageNet mean/std, as expected by pretrained ResNets
        ])
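        # the transform yields a float tensor of shape [3, H, W]; with Resize disabled,
        # H and W are whatever the source .png provides (nominally height x width above)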

    @cached_property
    def symbols(self):
        return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))

    def __getitem__(self, index):
        path = self.paths[index]

        # try/except is a holdover from the original VALL-E training framework, which could
        # insert foreign symbols into the symmap; that functionality isn't needed here
        try:
            text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
        except Exception as e:
            _logger.error(f"Invalid symbol in '{path.stem}': {_get_symbols(path.stem)} -> {[*map(self.symmap.get, _get_symbols(path.stem))]}")
            raise e

        image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # ResNet expects RGB input

        return dict(
            index=index,
            path=path,
            image=image,
            text=text,
        )

    def head_(self, n):
        self._head = n

    def training_(self, value):
        self.training = value

    def __len__(self):
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        # vestigial: the DataLoader below is created with pin_memory=False, so this is never called
        self.text = self.text.pin_memory()
        self.image = self.image.pin_memory()
        return self

def collate_fn(samples: list[dict]):
    # transpose a list of sample dicts into a dict of per-key lists; nothing is stacked or
    # padded here, so "image" and "text" arrive downstream as lists of tensors
    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch
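
# A minimal sketch (not used elsewhere in this module) of how a consumer might turn the
# collated lists into batched tensors: texts are variable-length, so they are padded with
# the symmap's space token (id 0); images are assumed to all share a common size.
def _stack_batch(batch: dict[str, Any]) -> dict[str, Any]:
    from torch.nn.utils.rnn import pad_sequence
    return dict(
        image=torch.stack(batch["image"]), # [B, 3, H, W]
        text=pad_sequence(batch["text"], batch_first=True, padding_value=0), # [B, T_max]
    )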

def _seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def _create_dataloader(dataset, training):
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        shuffle=True, # training
        drop_last=training,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 0,
        pin_memory=False, # True,
        worker_init_fn=_seed_worker,
    )

def _load_train_val_paths( val_ratio=0.1 ):
    paths = []
    train_paths = []
    val_paths = []

    _logger.debug(f"training directories: {cfg.dataset.training}")
    for data_dir in cfg.dataset.training:
        paths.extend(data_dir.rglob("*.png"))

    if len(paths) > 0:
        random.seed(0)
        random.shuffle(paths)
        train_paths.extend(paths)

    if len(cfg.dataset.validation) == 0:
        # no explicit validation set: carve the tail of the shuffled training paths off as validation
        val_len = math.floor(len(train_paths) * val_ratio)
        train_len = math.floor(len(train_paths) * (1 - val_ratio))
        _logger.debug(f"train/val split: {train_len}/{val_len}")
        val_paths = train_paths[-val_len:]
        train_paths = train_paths[:train_len]
    else:
        paths = [] # reset; `paths` still holds the training paths at this point
        for data_dir in cfg.dataset.validation:
            paths.extend(data_dir.rglob("*.png"))

        if len(paths) > 0:
            random.seed(0)
            random.shuffle(paths)
            val_paths.extend(paths)

    train_paths, val_paths = map(sorted, [train_paths, val_paths])

    if len(train_paths) == 0:
        raise RuntimeError(f"Failed to find any .png files in {cfg.dataset.training}.")

    # keep downstream code happy when there's nothing to validate against
    if len(val_paths) == 0:
        val_paths = [ train_paths[0] ]

    return train_paths, val_paths

@cfg.diskcache()
def create_datasets():
    train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(
        train_paths,
        training=True,
    )

    val_dataset = Dataset(
        val_paths,
        symmap=train_dataset.symmap,
    )
    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset

def create_train_val_dataloader():
    train_dataset, val_dataset = create_datasets()

    # a small, non-training copy of the training set, used for evaluation on seen data
    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.head_(cfg.evaluation.size)
    subtrain_dataset.training_(False)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.symmap))
    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
"""
if __name__ == "__main__":
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
"""