221 lines
5.6 KiB
Python
Executable File
221 lines
5.6 KiB
Python
Executable File
# todo: clean this mess up
|
|
|
|
import copy
|
|
# import h5py
|
|
import json
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
import random
|
|
import torch
|
|
import math
|
|
|
|
from .config import cfg
|
|
|
|
from collections import defaultdict
|
|
from functools import cache, cached_property
|
|
from itertools import groupby, zip_longest
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from torch import Tensor
|
|
from torch.utils.data import DataLoader, Dataset as _Dataset
|
|
import torchvision.transforms as transforms
|
|
from tqdm.auto import tqdm
|
|
|
|
from PIL import Image
|
|
|
|
# torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
@cache
def get_symmap():
    """Return the fixed symbol -> integer id vocabulary.

    The result is memoized, so every call returns the same dict object.
    Ids 0-2 are the space, "<s>" and "</s>" tokens; ids 3-22 are the characters
    0 2 4 8 A D G H J K M N P R S T V W X Y, in that order.
    """
    vocabulary = [" ", "<s>", "</s>", *"0248ADGHJKMNPRSTVWXY"]
    return {symbol: position for position, symbol in enumerate(vocabulary)}
|
|
|
|
@cache
|
|
def _get_symbols( content ):
|
|
content = content.replace("O", "0")
|
|
return [f"<s>"] + [ p for p in content ] + [f"</s>"]
|
|
|
|
class Dataset(_Dataset):
    """Map-style dataset of (image, text) samples built from a list of image paths.

    The target text of each sample is derived from the file *stem* (filename
    without extension) via ``_get_symbols`` and encoded with ``symmap``.
    """

    def __init__(
        self,
        paths,
        width=300,
        height=80,
        symmap=None,
        training=False,
    ):
        """
        Args:
            paths: indexable sequence of image paths. Each must expose ``.stem``
                and be openable by ``PIL.Image.open`` (e.g. ``pathlib.Path``).
            width: stored on the instance; the resize transform that would use
                it is currently disabled.
            height: stored on the instance; see ``width``.
            symmap: mapping from symbol to integer id. ``None`` (the default)
                means ``get_symmap()``.
            training: stored flag, toggled later through ``training_``.
        """
        super().__init__()

        # Optional cap on the number of samples reported by __len__ (see head_).
        self._head = None

        self.paths = paths
        self.width = width
        self.height = height

        # Resolve the default lazily. The previous ``symmap=get_symmap()``
        # default was evaluated once at class-definition time and shared
        # between all instances (a mutable default argument).
        self.symmap = get_symmap() if symmap is None else symmap
        self.training = training

        self.transform = transforms.Compose([
            #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
            transforms.ToTensor(),
            # These are the commonly used ImageNet channel mean/std. They are
            # presumably here for a pretrained ResNet backbone (see the "resnet"
            # note in __getitem__) — confirm.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    @cached_property
    def symbols(self):
        """Sorted list of every distinct symbol (including "<s>"/"</s>") over all path stems."""
        return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with keys ``index``, ``path``, ``image`` (the normalized tensor,
            cast to ``cfg.trainer.dtype``) and ``text`` (the symbol ids as a
            ``torch.uint8`` tensor).
        """
        path = self.paths[index]

        symbols = _get_symbols(path.stem)
        # Unknown symbols map to None via .get(), which makes torch.tensor fail.
        # This stupid try/except dates from when the original VALL-E training
        # framework could insert foreign symbols into the symmap; it just prints
        # context for the offending file before re-raising.
        ids = [*map(self.symmap.get, symbols)]
        try:
            text = torch.tensor(ids).to(torch.uint8)
        except Exception:
            print("Invalid symbol:", symbols, ids, path.stem)
            raise

        # Open inside a context manager so the file handle is released right
        # away instead of whenever the garbage collector gets to it.
        # convert() returns a new, fully loaded image, so it is safe to use
        # after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB') # resnet has to be RGB
        image = self.transform(rgb).to(cfg.trainer.dtype)

        return dict(
            index=index,
            path=path,
            image=image,
            text=text,
        )

    def head_(self, n):
        """Limit ``__len__`` to at most ``n`` samples (``None`` or 0 disables the cap)."""
        self._head = n

    def training_(self, value):
        """Set the ``training`` flag in place."""
        self.training = value

    def __len__(self):
        # ``self._head or ...``: a head of None or 0 means "no cap".
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        # NOTE(review): Dataset never assigns self.text / self.image, so calling
        # this raises AttributeError. It looks like a leftover from a batch type.
        # The DataLoaders in this file are built with pin_memory=False.
        self.text = self.text.pin_memory()
        self.image = self.image.pin_memory()
        return self
|
|
|
|
|
|
def collate_fn(samples: list[dict]):
    """Collate per-sample dicts into one dict of lists, keyed like the first sample.

    Nothing is stacked: each value of the result is a plain list with one entry
    per sample, in input order. Raises IndexError on an empty batch and KeyError
    if a later sample lacks one of the first sample's keys.
    """
    keys = list(samples[0])
    batch: dict[str, Any] = {}
    for key in keys:
        batch[key] = [sample[key] for sample in samples]
    return batch
|
|
|
|
|
|
def _seed_worker(worker_id):
|
|
worker_seed = torch.initial_seed() % 2**32
|
|
np.random.seed(worker_seed)
|
|
random.seed(worker_seed)
|
|
|
|
|
|
def _create_dataloader(dataset, training):
    """Wrap ``dataset`` in a DataLoader configured from ``cfg``.

    Training loaders use ``cfg.hyperparameters.batch_size`` and drop the last
    partial batch. Evaluation loaders use ``cfg.evaluation.batch_size`` and keep it.
    """
    if training:
        batch_size = cfg.hyperparameters.batch_size
    else:
        batch_size = cfg.evaluation.batch_size
    workers = cfg.dataset.workers

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        # NOTE(review): shuffling is on for evaluation loaders too. An earlier
        # inline "# training" comment hints that shuffle=training may have been
        # intended — confirm.
        shuffle=True,
        drop_last=training,
        num_workers=workers,
        collate_fn=collate_fn,
        persistent_workers=workers > 0,
        pin_memory=False, # True,
        worker_init_fn=_seed_worker,
    )
|
|
|
|
def _load_train_val_paths( val_ratio=0.1 ):
|
|
paths = []
|
|
train_paths = []
|
|
val_paths = []
|
|
|
|
print(cfg.dataset.training)
|
|
for data_dir in cfg.dataset.training:
|
|
paths.extend(data_dir.rglob("*.png"))
|
|
|
|
if len(paths) > 0:
|
|
random.seed(0)
|
|
random.shuffle(paths)
|
|
train_paths.extend(paths)
|
|
|
|
if len(cfg.dataset.validation) == 0:
|
|
val_len = math.floor(len(train_paths) * val_ratio)
|
|
train_len = math.floor(len(train_paths) * (1 - val_ratio))
|
|
|
|
print(val_len, train_len)
|
|
|
|
val_paths = train_paths[:-val_len]
|
|
train_paths = train_paths[:train_len]
|
|
else:
|
|
for data_dir in cfg.dataset.validation:
|
|
paths.extend(data_dir.rglob("*.png"))
|
|
|
|
if len(paths) > 0:
|
|
random.seed(0)
|
|
random.shuffle(paths)
|
|
val_paths.extend(paths)
|
|
|
|
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
|
|
|
if len(train_paths) == 0:
|
|
raise RuntimeError(f"Failed to find any .png file in {cfg.dataset.training}.")
|
|
# to get it to shut up
|
|
if len(val_paths) == 0:
|
|
val_paths = [ train_paths[0] ]
|
|
|
|
return train_paths, val_paths
|
|
|
|
@cfg.diskcache()
def create_datasets():
    """Build the training and validation ``Dataset`` objects.

    The validation set shares the training set's symbol map and is capped at
    ``cfg.evaluation.size`` samples. The function is wrapped by
    ``cfg.diskcache()``, presumably an on-disk memoizer — confirm; stale cached
    results may need clearing after changes here.
    """
    train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(
        train_paths,
        training=True,
    )

    # Pass the vocabulary by keyword: Dataset.__init__'s second positional
    # parameter is ``width``, so the previous positional call stored the symmap
    # as the image width and left the validation set on the default symmap.
    val_dataset = Dataset(
        val_paths,
        symmap=train_dataset.symmap,
    )

    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset
|
|
|
|
|
|
def create_train_val_dataloader():
    """Create and return the ``(train, subtrain, val)`` DataLoaders.

    ``subtrain`` is a deep copy of the training dataset, capped at
    ``cfg.evaluation.size`` samples and taken out of training mode. It serves
    as an evaluation view over training data.
    """
    train_dataset, val_dataset = create_datasets()

    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.head_(cfg.evaluation.size)
    subtrain_dataset.training_(False)

    # Loaders are built in train, val, subtrain order.
    train_dl, val_dl, subtrain_dl = (
        _create_dataloader(dataset, training=is_training)
        for dataset, is_training in (
            (train_dataset, True),
            (val_dataset, False),
            (subtrain_dataset, False),
        )
    )

    _logger.info(str(train_dataset.symmap))

    for label, dataset in (
        ("train", train_dataset),
        ("val", val_dataset),
        ("subtrain", subtrain_dataset),
    ):
        _logger.info(f"#samples ({label}): {len(dataset)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
|
|
|
|
"""
|
|
if __name__ == "__main__":
|
|
create_dataset_hdf5()
|
|
|
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
|
sample = train_dl.dataset[0]
|
|
print(sample)
|
|
"""
|