# todo: clean this mess up

import copy
import logging
import math
import random
from functools import cache, cached_property
from typing import Any

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset as _Dataset

from .config import cfg

# torch.multiprocessing.set_sharing_strategy("file_system")

_logger = logging.getLogger(__name__)


@cache
def get_symmap():
    return {
        " ": 0,
        "<s>": 1,   # sequence start token
        "</s>": 2,  # sequence end token
        "0": 3, "2": 4, "4": 5, "8": 6,
        "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14,
        "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22,
    }


@cache
def _get_symbols(content):
    # "O" is not in the symbol set; treat it as "0"
    content = content.replace("O", "0")
    return ["<s>"] + [p for p in content] + ["</s>"]


class Dataset(_Dataset):
    def __init__(
        self,
        paths,
        width=300,
        height=80,
        symmap=get_symmap(),
        training=False,
    ):
        super().__init__()
        self._head = None
        self.paths = paths
        self.width = width
        self.height = height
        self.symmap = symmap
        self.training = training

        # resize to a fixed size and normalize with ImageNet statistics
        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @cached_property
    def symbols(self):
        return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))

    def __getitem__(self, index):
        path = self.paths[index]

        # the label is encoded in the filename stem; map it to symbol ids
        try:
            text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
        except Exception as e:
            _logger.error(
                f"Invalid symbol in {path.stem}: {_get_symbols(path.stem)} => "
                f"{[*map(self.symmap.get, _get_symbols(path.stem))]}"
            )
            raise e

        image = self.transform(Image.open(path).convert("RGB")).to(cfg.trainer.dtype)

        return dict(
            index=index,
            path=path,
            image=image,
            text=text,
        )

    def head_(self, n):
        # limit the dataset to its first n samples
        self._head = n

    def training_(self, value):
        self.training = value

    def __len__(self):
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        # unused: the dataloader is created with pin_memory=False
        self.text = self.text.pin_memory()
        self.image = self.image.pin_memory()
        return self


def collate_fn(samples: list[dict]):
    # group per-sample fields into lists; no padding or stacking is done here
    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch


def _seed_worker(worker_id):
    # derive per-worker numpy/python seeds from the torch worker seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def _create_dataloader(dataset, training):
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        shuffle=True,  # training
        drop_last=training,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 0,
        pin_memory=False,  # True,
        worker_init_fn=_seed_worker,
    )


def _load_train_val_paths(val_ratio=0.1):
    paths = []
    train_paths = []
    val_paths = []

    _logger.info(f"Loading training data from: {cfg.dataset.training}")

    for data_dir in cfg.dataset.training:
        paths.extend(data_dir.rglob("*.png"))

    if len(paths) > 0:
        random.seed(0)
        random.shuffle(paths)
        train_paths.extend(paths)

    if len(cfg.dataset.validation) == 0:
        # no explicit validation set: carve the tail off the shuffled training paths
        val_len = math.floor(len(train_paths) * val_ratio)
        train_len = math.floor(len(train_paths) * (1 - val_ratio))
        _logger.info(f"Split sizes: train={train_len}, val={val_len}")

        val_paths = train_paths[train_len:]
        train_paths = train_paths[:train_len]
    else:
        paths = []  # do not reuse the training paths
        for data_dir in cfg.dataset.validation:
            paths.extend(data_dir.rglob("*.png"))

        if len(paths) > 0:
            random.seed(0)
            random.shuffle(paths)
            val_paths.extend(paths)
    train_paths, val_paths = map(sorted, [train_paths, val_paths])

    if len(train_paths) == 0:
        raise RuntimeError(f"Failed to find any .png file in {cfg.dataset.training}.")

    # to get it to shut up
    if len(val_paths) == 0:
        val_paths = [train_paths[0]]

    return train_paths, val_paths


@cfg.diskcache()
def create_datasets():
    train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(
        train_paths,
        training=True,
    )

    val_dataset = Dataset(
        val_paths,
        symmap=train_dataset.symmap,
    )
    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset


def create_train_val_dataloader():
    train_dataset, val_dataset = create_datasets()

    # a small copy of the training set, evaluated with the same settings as validation
    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.head_(cfg.evaluation.size)
    # subtrain_dataset.training_(False)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.symmap))
    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl


if __name__ == "__main__":
    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
    sample = train_dl.dataset[0]
    print(sample)
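
    # Sketch, not part of the original module: since collate_fn only groups fields
    # into lists, a consumer is expected to stack the fixed-size images itself.
    # Variable-length `text` tensors are left as a list here.
    batch = next(iter(train_dl))
    images = torch.stack(batch["image"])  # [B, 3, height, width]
    print(images.shape, len(batch["text"]))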