resnet-classifier/image_classifier/data.py

# todo: clean this mess up
import copy
# import h5py
import json
import logging
#import numpy as np
import os
import random
import torch
import math
from .config import cfg
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from PIL import Image
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)

@cache
def get_symmap():
	return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }

@cache
def _get_symbols( content ):
	content = content.replace("O", "0") # "O" has no entry in the symmap, so fold it into "0"
	return ["<s>"] + [ p for p in content ] + ["</s>"]
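
# e.g. a (hypothetical) file stem "2XK4" expands to ["<s>", "2", "X", "K", "4", "</s>"],
# which the symmap above maps to the indices [1, 4, 21, 12, 5, 2]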

class Dataset(_Dataset):
	def __init__(
		self,
		paths,
		width=300,
		height=80,
		symmap=get_symmap(),
		training=False,
	):
		super().__init__()
		self._head = None
		self.paths = paths
		self.width = width
		self.height = height
		self.symmap = symmap
		self.training = training

		self.transform = transforms.Compose([
			#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
			transforms.ToTensor(),
			transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
		])
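		# assuming the pngs already come in at width x height (the Resize above is disabled),
		# ToTensor + Normalize turns each RGB image into a (3, height, width) float tensor;
		# the Normalize statistics are the standard ImageNet ones used with torchvision's
		# pretrained ResNet weights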

	@cached_property
	def symbols(self):
		return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))

	def __getitem__(self, index):
		path = self.paths[index]
		# try/except left over from when the original VALL-E training framework could insert foreign symbols into the symmap; that functionality isn't really necessary here
		try:
			text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
		except Exception as e:
			print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
			raise e

		image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet expects RGB

		return dict(
			index=index,
			path=path,
			image=image,
			text=text,
		)

	def head_(self, n):
		self._head = n

	def training_(self, value):
		self.training = value

	def __len__(self):
		return min(len(self.paths), self._head or len(self.paths))

	def pin_memory(self):
		# note: effectively unused; __getitem__ never sets self.text / self.image, and the
		# dataloaders below are created with pin_memory=False
		self.text = self.text.pin_memory()
		self.image = self.image.pin_memory()
		return self

def collate_fn(samples: list[dict]):
	batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
	return batch
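
# collate_fn keeps each field as a plain Python list, one entry per sample. A rough
# sketch of how a consumer might tensorize a batch; this is an assumption (the training
# loop lives outside this file) and only holds if every image shares the same size and
# every label the same length:
#
#   batch = collate_fn([dataset[0], dataset[1]])
#   images = torch.stack(batch["image"])      # (B, 3, H, W)
#   text = torch.stack(batch["text"]).long()  # (B, T)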

def _seed_worker(worker_id):
	worker_seed = torch.initial_seed() % 2**32
	#np.random.seed(worker_seed)
	random.seed(worker_seed)

def _create_dataloader(dataset, training):
	return DataLoader(
		dataset=dataset,
		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
		shuffle=True, # training
		drop_last=training,
		num_workers=cfg.dataset.workers,
		collate_fn=collate_fn,
		persistent_workers=cfg.dataset.workers > 0,
		pin_memory=False, # True,
		worker_init_fn=_seed_worker,
	)

def _load_train_val_paths( val_ratio=0.1 ):
	paths = []
	train_paths = []
	val_paths = []

	print(cfg.dataset.training)
	for data_dir in cfg.dataset.training:
		paths.extend(data_dir.rglob("*.png"))

	if len(paths) > 0:
		random.seed(0)
		random.shuffle(paths)
		train_paths.extend(paths)

	if len(cfg.dataset.validation) == 0:
		# no explicit validation set: carve off the last val_ratio of the shuffled training paths
		val_len = math.floor(len(train_paths) * val_ratio)
		train_len = math.floor(len(train_paths) * (1 - val_ratio))
		print(val_len, train_len)
		val_paths = train_paths[train_len:]
		train_paths = train_paths[:train_len]
	else:
		paths = [] # reset so the training paths gathered above don't leak into the validation set
		for data_dir in cfg.dataset.validation:
			paths.extend(data_dir.rglob("*.png"))

		if len(paths) > 0:
			random.seed(0)
			random.shuffle(paths)
			val_paths.extend(paths)

	train_paths, val_paths = map(sorted, [train_paths, val_paths])

	if len(train_paths) == 0:
		raise RuntimeError(f"Failed to find any .png file in {cfg.dataset.training}.")
	# to get it to shut up
	if len(val_paths) == 0:
		val_paths = [ train_paths[0] ]

	return train_paths, val_paths
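
# Worked example of the fallback split above: with 1000 training pngs and the default
# val_ratio=0.1, val_len = 100 and train_len = 900, so after the seeded shuffle the
# first 900 paths remain in train_paths and the last 100 become val_paths.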

@cfg.diskcache()
def create_datasets():
	train_paths, val_paths = _load_train_val_paths()

	train_dataset = Dataset(
		train_paths,
		training=True,
	)

	val_dataset = Dataset(
		val_paths,
		symmap=train_dataset.symmap,
	)
	val_dataset.head_(cfg.evaluation.size)

	return train_dataset, val_dataset

def create_train_val_dataloader():
	train_dataset, val_dataset = create_datasets()

	subtrain_dataset = copy.deepcopy(train_dataset)
	subtrain_dataset.head_(cfg.evaluation.size)
	subtrain_dataset.training_(False)

	train_dl = _create_dataloader(train_dataset, training=True)
	val_dl = _create_dataloader(val_dataset, training=False)
	subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

	_logger.info(str(train_dataset.symmap))
	_logger.info(f"#samples (train): {len(train_dataset)}.")
	_logger.info(f"#samples (val): {len(val_dataset)}.")
	_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

	assert isinstance(subtrain_dl.dataset, Dataset)

	return train_dl, subtrain_dl, val_dl
"""
if __name__ == "__main__":
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
"""