# resnet-classifier/image_classifier/data.py

# todo: clean this mess up
import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch
import itertools
from .config import cfg
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
# to-do: clean up this symmap mess
def get_symmap():
	return cfg.tokenizer.get_vocab()

def tokenize( s ):
	if isinstance( s, list ):
		s = "".join( s )
	return cfg.tokenizer.encode( s )
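# Rough idea of what these helpers return; the exact ids depend entirely on
# cfg.tokenizer's vocab, so the values below are illustrative only:
#   get_symmap()      -> {"A": 5, "B": 6, ...}
#   tokenize("HELLO") -> [12, 9, 16, 16, 19]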
"""
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_hdf5_path(path):
# to-do: better validation
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, validate=False ):
if isinstance(path, str):
path = Path(path)
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value
"""
class Dataset(_Dataset):
	def __init__(
		self,
		paths,
		width=300,
		height=80,
		stacks=0,
		symmap=get_symmap(),
		training=False,
	):
		super().__init__()
		self._head = None
		self.sampler = None

		self.width = width
		self.height = height
		self.stacks = stacks
		self.paths = paths
		self.image_dtype = cfg.trainer.dtype
		self.symmap = symmap
		self.training = training
		self.dataset_type = "training" if self.training else "validation"
		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation

		self.transform = transforms.Compose([
			transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
			transforms.ToTensor(),
			transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
		])

		# to-do: skip validation entirely when no validation set is configured
		# falling back to the training set just keeps the evaluation loop happy
		if len(self.dataset) == 0:
			self.dataset = cfg.dataset.training
		# shard the dataset across GPUs so each rank sees a distinct subset
		if cfg.distributed and self.training:
			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]

		if len(self.paths) == 0:
			raise ValueError(f"No valid paths found for {self.dataset_type}")
	@cached_property
	def sampler_state_dict_path(self):
		return cfg.rel_path / f"sampler.rank{global_rank()}.pt"

	def save_state_dict(self, path = None):
		"""
		if path is None:
			path = self.sampler_state_dict_path

		if self.sampler is not None:
			state_dict = self.sampler.get_state()
		elif self.samplers is not None:
			state_dict = {
				"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
			}
		torch_save(state_dict, path)
		"""
		return

	def load_state_dict(self, path = None):
		"""
		if path is None:
			path = self.sampler_state_dict_path

		if not path.exists():
			return

		state_dict = torch_load(path)
		if self.sampler is not None:
			state_dict = self.sampler.set_state(state_dict)
		else:
			for name, sampler in state_dict["samplers"].items():
				if name not in self.samplers:
					continue
				self.samplers[name].set_state( sampler )
		"""
		return
	def __getitem__(self, index):
		path = self.paths[index]
		# the label text is taken from the file name stem, upper-cased
		text = path.stem.upper()
		image = Image.open(path).convert('RGB')
		width, height = image.size

		text = torch.tensor( tokenize( text ) ).to(dtype=torch.uint8)
		image = self.transform(image).to(dtype=self.image_dtype)

		return dict(
			index=index,
			path=path,
			image=image,
			text=text,
		)

	def head_(self, n):
		self._head = n

	def training_(self, value):
		self.training = value

	def __len__(self):
		return min(len(self.paths), self._head or len(self.paths))
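# A rough sketch of what a single sample looks like; the path below is hypothetical
# and the image dtype follows cfg.trainer.dtype:
#
#   dataset = Dataset([Path("data/train/HELLO.png")], training=True)
#   sample = dataset[0]
#   # sample["image"] -> tensor of shape [3, 80, 300] (resized + normalized)
#   # sample["text"]  -> uint8 tensor of token ids for the upper-cased file stem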
def collate_fn(samples: list[dict]):
	# collate into a dict of lists; nothing is stacked into batched tensors here
	batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
	return batch
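# Since every image is resized to the same (height, width), a batch can be stacked
# downstream if needed; a minimal sketch, assuming a DataLoader built further below:
#
#   batch  = next(iter(train_dl))
#   images = torch.stack(batch["image"])                   # [B, 3, 80, 300]
#   texts  = pad_sequence(batch["text"], batch_first=True) # [B, max_label_len]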
def _seed_worker(worker_id):
	worker_seed = torch.initial_seed() % 2**32
	np.random.seed(worker_seed)
	random.seed(worker_seed)
def _create_dataloader(dataset, training):
	kwargs = dict(
		# DataLoader refuses shuffle=True together with an explicit sampler
		shuffle=dataset.sampler is None,
		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
		drop_last=training,
		sampler=dataset.sampler if training else None,
	) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
		batch_sampler=dataset.sampler,
	)
	return DataLoader(
		dataset=dataset,
		num_workers=cfg.dataset.workers,
		collate_fn=collate_fn,
		persistent_workers=cfg.dataset.workers > 1,
		pin_memory=False,
		worker_init_fn=_seed_worker,
		**kwargs,
	)
def _load_train_val_paths( val_ratio=0.1 ):
	train_paths = []
	val_paths = []

	for data_dir in cfg.dataset.training:
		train_paths.extend(data_dir.rglob("*.jpg"))
		train_paths.extend(data_dir.rglob("*.png"))

	if len(train_paths) > 0:
		random.seed(0)
		random.shuffle(train_paths)

	for data_dir in cfg.dataset.validation:
		val_paths.extend(data_dir.rglob("*.jpg"))
		val_paths.extend(data_dir.rglob("*.png"))

	if len(val_paths) > 0:
		random.seed(0)
		random.shuffle(val_paths)

	train_paths, val_paths = map(sorted, [train_paths, val_paths])

	if len(train_paths) == 0:
		raise RuntimeError(f"Failed to find any .jpg/.png files in {cfg.dataset.training}.")

	# keep the validation split non-empty so evaluation code does not complain
	if len(val_paths) == 0:
		val_paths = [ train_paths[0] ]

	return train_paths, val_paths
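# Labels come from file names: __getitem__ upper-cases the stem and tokenizes it.
# A hypothetical layout (not enforced anywhere in this file) would therefore be:
#
#   data/train/HELLO.png  -> label tokens for "HELLO"
#   data/train/W0RLD.jpg  -> label tokens for "W0RLD"
#
# cfg.dataset.training / cfg.dataset.validation are expected to be lists of such directories.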
def create_datasets():
	train_paths, val_paths = _load_train_val_paths()

	train_dataset = Dataset(
		train_paths,
		training=True,
	)
	val_dataset = Dataset(
		val_paths,
		# pass the symmap by keyword; positionally it would land in the `width` slot
		symmap=train_dataset.symmap,
	)
	val_dataset.head_(cfg.evaluation.size)

	return train_dataset, val_dataset
def create_train_val_dataloader():
	train_dataset, val_dataset = create_datasets()

	# deepcopy is slow, but gives an evaluation-mode view of the training set
	subtrain_dataset = copy.deepcopy(train_dataset)
	subtrain_dataset.head_(cfg.evaluation.size)
	subtrain_dataset.training_(False)

	train_dl = _create_dataloader(train_dataset, training=True)
	val_dl = _create_dataloader(val_dataset, training=False)
	subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

	_logger.info(str(train_dataset.symmap))
	_logger.info(f"#samples (train): {len(train_dataset)}.")
	_logger.info(f"#samples (val): {len(val_dataset)}.")
	_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

	assert isinstance(subtrain_dl.dataset, Dataset)

	return train_dl, subtrain_dl, val_dl
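# A hedged sketch of how these loaders might be consumed by a trainer; nothing below
# is called anywhere in this file, and the loop body is only a placeholder:
#
#   train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
#   for batch in train_dl:
#       images  = torch.stack(batch["image"])                   # [B, 3, 80, 300]
#       targets = pad_sequence(batch["text"], batch_first=True) # [B, max_label_len]
#       # ... forward pass / loss / optimizer step ...
#   # subtrain_dl and val_dl are then used for periodic evaluation on small subsets.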
# parse the dataset into metadata that is easier to sample from
# (disabled: the block below still targets the original audio/HDF5 pipeline)
"""
def create_dataset_metadata( skip_existing=True ):
symmap = get_symmap()
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
def add( dir, type="training", audios=True, texts=True ):
name = str(dir)
name = name.replace(root, "")
speaker_name = name
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
try:
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
except Exception as e:
metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
# tqdm.write(f'{root}/{name}')
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
wrote = False
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
if audios and not quant_path.exists():
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and id in metadata:
continue
wrote = True
if id not in metadata:
metadata[id] = {}
utterance_metadata = {}
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(quant_path, allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
utterance_metadata["text"] = dac["metadata"]["text"]
if "phonemes" in dac["metadata"]:
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
if "language" in dac["metadata"]:
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
for k, v in utterance_metadata.items():
metadata[id][k] = v
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
if wrote:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# parse yaml to create an hdf5 file
def create_dataset_hdf5( skip_existing=True ):
	cfg.dataset.use_hdf5 = True
	cfg.load_hdf5(write=True)
	hf = cfg.hdf5

	symmap = get_symmap()

	root = str(cfg.data_dir)
	metadata_root = str(cfg.metadata_dir)

	def add( dir, type="training", audios=True, texts=True ):
		name = str(dir)
		name = name.replace(root, "")

		# yucky
		speaker_name = name
		if "LibriTTS-R" in speaker_name:
			speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

		metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
		metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
		metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())

		if not os.path.isdir(f'{root}/{name}/'):
			return

		files = os.listdir(f'{root}/{name}/')

		# grab IDs for every file
		ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

		for id in tqdm(ids, desc=f"Processing {name}"):
			try:
				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True

				if not quant_exists:
					continue

				key = f'{type}/{speaker_name}/{id}'

				if skip_existing and key in hf:
					continue

				group = hf.create_group(key) if key not in hf else hf[key]

				if id not in metadata:
					metadata[id] = {}

				utterance_metadata = {}

				# audio
				if audios:
					dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
					qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

					if "text" in dac["metadata"]:
						utterance_metadata["text"] = dac["metadata"]["text"]
					if "phonemes" in dac["metadata"]:
						utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
					if "language" in dac["metadata"]:
						utterance_metadata["language"] = dac["metadata"]["language"]
					if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
						utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]

					if "audio" not in group:
						group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

				# text
				if texts:
					if not utterance_metadata and text_exists:
						utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())

					phn = "".join(utterance_metadata["phonemes"])
					phn = cfg.tokenizer.encode(phn)
					phn = np.array(phn).astype(np.uint8)

					if "text" not in group:
						group.create_dataset('text', data=phn, compression='lzf')

				for k, v in utterance_metadata.items():
					group.attrs[k] = v
					metadata[id][k] = v
			except Exception as e:
				tqdm.write(f'Error while processing {id}: {e}')

		with open(str(metadata_path), "w", encoding="utf-8") as f:
			f.write( json.dumps( metadata ) )

	# training
	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
		add( data_dir, type="training" )

	# validation
	for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
		add( data_dir, type="validation" )

	# noise
	for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
		add( data_dir, type="noise", texts=False )

	# write symmap
	if "symmap" in hf:
		del hf['symmap']

	hf.create_dataset('symmap', data=json.dumps(symmap))

	hf.close()
if __name__ == "__main__":
	import argparse

	parser = argparse.ArgumentParser("Save trained model to path.")
	parser.add_argument("--action", type=str)
	parser.add_argument("--tasks", type=str)
	args, unknown = parser.parse_known_args()

	task = args.action

	cfg.dataset.workers = 1

	if args.action == "hdf5":
		create_dataset_hdf5()
	elif args.action == "list-dataset":
		dataset = []
		for group in os.listdir(cfg.data_dir):
			for name in os.listdir(cfg.data_dir / group):
				if len(os.listdir(cfg.data_dir / group / name)) == 0:
					continue
				dataset.append(f'{group}/{name}')

		_logger.info(json.dumps(dataset))
	elif args.action == "metadata":
		create_dataset_metadata()
	elif args.action == "sample":
		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

		samples = {
			"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
			"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
			#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
		}

		Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)

		for k, v in samples.items():
			for i in range(len(v)):
				for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
					try:
						decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
					except Exception as e:
						_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
					try:
						decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
					except Exception as e:
						_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
					v[i]['proms'][j] = v[i]['proms'][j].shape
					v[i]['resps'][j] = v[i]['resps'][j].shape

		for k, v in samples.items():
			for i in range(len(v)):
				_logger.info(f'{k}[{i}]: {v[i]}')
	elif args.action == "validate":
		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

		missing = set()

		for i in range(len( train_dl.dataset )):
			batch = train_dl.dataset[i]

			text = batch['text']
			phonemes = batch['metadata']['phonemes']

			decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
			for i, token in enumerate(decoded):
				if token != "<unk>":
					continue

				phone = phonemes[i]

				_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )

				missing |= set([phone])

		_logger.info( f"Missing tokens: {missing}" )
	elif args.action == "tasks":
		index = 0
		cfg.dataset.tasks_list = args.tasks.split(",")

		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
		batch = next(iter(train_dl))

		for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
			if task not in cfg.dataset.tasks_list:
				continue

			_logger.info( f'{text} {task} {cfg.model.resp_levels}')
			_logger.info( f'{proms.shape} {resps.shape}' )

			tokens = 0
			tokens += sum([ text.shape[0] for text in batch["text"] ])
			tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
			_logger.info( f'{tokens}' )

			decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
			decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
			break
"""
if __name__ == "__main__":
	...