# todo: clean this mess up
import copy
import h5py
import json
import logging
import math
import numpy as np
import os
import random
import torch
import itertools

from .config import cfg
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load

from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any

from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence

from PIL import Image, ImageDraw
import torchvision.transforms as transforms

from tqdm.auto import tqdm

# torch.multiprocessing.set_sharing_strategy("file_system")

_logger = logging.getLogger(__name__)

# to-do: clean up this symmap mess
def get_symmap():
    return cfg.tokenizer.get_vocab()

def tokenize( s ):
    if isinstance( s, list ):
        s = "".join( s )
    return cfg.tokenizer.encode( s )

"""
def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

def _get_hdf5_path(path):
    # to-do: better validation
    return str(path)

def _get_hdf5_paths( data_dir, type="training", validate=False ):
    data_dir = str(data_dir)

    key = f"/{type}/{_get_hdf5_path(data_dir)}"
    return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() ] if key in cfg.hdf5 else []

def _get_paths_of_extensions( path, validate=False ):
    if isinstance(path, str):
        path = Path(path)

    return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []

def _interleaved_reorder(l, fn):
    groups = defaultdict(list)
    for e in l:
        groups[fn(e)].append(e)
    groups = {k: groups[k] for k in sorted(groups)}
    for interleaved in zip_longest(*groups.values()):
        for value in interleaved:
            if value is not None:
                yield value
"""
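# How a label is derived from a filename (an illustrative sketch, not part of the
# module; "abc123.png" is a hypothetical file and cfg.tokenizer is assumed to be loaded):
#
#   stem = Path("abc123.png").stem.upper()         # -> "ABC123"
#   tokens = tokenize( stem )                      # cfg.tokenizer.encode(...) -> list of token ids
#   text = torch.tensor( tokens ).to(dtype=torch.uint8)
#
# Dataset.__getitem__ below applies exactly this to each image path's stem.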
class Dataset(_Dataset):
    def __init__(
        self,
        paths,
        width=300,
        height=80,
        stacks=0,
        symmap=get_symmap(),
        training=False,
    ):
        super().__init__()
        self._head = None
        self.sampler = None

        self.width = width
        self.height = height
        self.stacks = stacks

        self.paths = paths
        self.image_dtype = cfg.trainer.dtype

        self.symmap = symmap
        self.training = training
        self.dataset_type = "training" if self.training else "validation"
        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation

        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            # for some reason, running the validation dataset breaks when this is set; all images *should* be normalized anyhow
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # to-do: do not do validation if there's nothing in the validation set
        # this just makes it be happy
        if len(self.dataset) == 0:
            self.dataset = cfg.dataset.training

        # split dataset accordingly per GPU: each rank keeps its own interleaved slice
        if cfg.distributed and self.training:
            self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]

        if len(self.paths) == 0:
            raise ValueError(f"No valid path is found for {self.dataset_type}")

    @cached_property
    def sampler_state_dict_path(self):
        return cfg.rel_path / f"sampler.rank{global_rank()}.pt"

    def save_state_dict(self, path = None):
        """
        if path is None:
            path = self.sampler_state_dict_path

        if self.sampler is not None:
            state_dict = self.sampler.get_state()
        elif self.samplers is not None:
            state_dict = {
                "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
            }

        torch_save(state_dict, path)
        """
        return

    def load_state_dict(self, path = None):
        """
        if path is None:
            path = self.sampler_state_dict_path

        if not path.exists():
            return

        state_dict = torch_load(path)
        if self.sampler is not None:
            state_dict = self.sampler.set_state(state_dict)
        else:
            for name, sampler in state_dict["samplers"].items():
                if name not in self.samplers:
                    continue
                self.samplers[name].set_state( sampler )
        """
        return

    def __getitem__(self, index):
        path = self.paths[index]

        # the filename stem doubles as the label text
        tokens = tokenize( path.stem.upper() )
        text = torch.tensor( tokens ).to(dtype=torch.uint8)

        image = Image.open(path).convert('RGB')
        width, height = image.size
        image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB

        return dict(
            index=index,
            path=path,
            image=image,
            text=text,
        )

    def head_(self, n):
        self._head = n

    def training_(self, value):
        self.training = value

    def __len__(self):
        return min(len(self.paths), self._head or len(self.paths))

def collate_fn(samples: list[dict]):
    # collate into a dict of lists; tensors are deliberately not stacked here
    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch

def _seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def _create_dataloader(dataset, training):
    # dataset.sampler is currently always None, so the plain shuffle branch is taken
    kwargs = dict(
        shuffle=True,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        drop_last=training,
        sampler=dataset.sampler if training else None,
    ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
        batch_sampler=dataset.sampler,
    )

    return DataLoader(
        dataset=dataset,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        **kwargs,
    )

def _load_train_val_paths( val_ratio=0.1 ):
    paths = []
    train_paths = []
    val_paths = []

    for data_dir in cfg.dataset.training:
        paths.extend(data_dir.rglob("*.jpg"))
        paths.extend(data_dir.rglob("*.png"))

    if len(paths) > 0:
        random.seed(0)
        random.shuffle(paths)
        train_paths.extend(paths)

    if len(cfg.dataset.validation) == 0:
        # no explicit validation set: hold out the tail of the shuffled training paths
        train_len = math.floor(len(train_paths) * (1 - val_ratio))

        val_paths = train_paths[train_len:]
        train_paths = train_paths[:train_len]
    else:
        paths = []
        for data_dir in cfg.dataset.validation:
            paths.extend(data_dir.rglob("*.jpg"))
            paths.extend(data_dir.rglob("*.png"))

        if len(paths) > 0:
            random.seed(0)
            random.shuffle(paths)
            val_paths.extend(paths)

    train_paths, val_paths = map(sorted, [train_paths, val_paths])

    if len(train_paths) == 0:
        raise RuntimeError(f"Failed to find any .jpg/.png files in {cfg.dataset.training}.")

    # to get it to shut up
    if len(val_paths) == 0:
        val_paths = [ train_paths[0] ]

    return train_paths, val_paths
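# What a single item and a collated batch look like (an illustrative sketch;
# "abc123.png" is a hypothetical file and cfg is assumed to be loaded):
#
#   ds = Dataset( [ Path("abc123.png") ] )
#   sample = ds[0]
#   # sample["image"]: float tensor of shape [3, height, width] (Resize -> ToTensor -> Normalize)
#   # sample["text"]:  uint8 tensor of token ids for "ABC123"
#   batch = collate_fn([ ds[0], ds[0] ])
#   # batch is a dict of lists: {"index": [...], "path": [...], "image": [...], "text": [...]};
#   # images are left as a list rather than stacked into a single tensor.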
def create_datasets():
    train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(
        train_paths,
        training=True,
    )

    val_dataset = Dataset(
        val_paths,
        symmap=train_dataset.symmap, # must be passed as a keyword; the second positional argument is `width`
    )
    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset

def create_train_val_dataloader():
    train_dataset, val_dataset = create_datasets()

    # deepcopy is slow
    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.head_(cfg.evaluation.size)
    subtrain_dataset.training_(False)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.symmap))
    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
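# How the three loaders are meant to be consumed (an illustrative sketch; `model`
# is a placeholder, not something defined in this module):
#
#   train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
#   for batch in train_dl:                     # batch_size = cfg.hyperparameters.batch_size
#       images = torch.stack(batch["image"])   # collate_fn leaves lists, so stack here
#       logits = model(images)
#
# subtrain_dl re-serves the first cfg.evaluation.size training samples with training
# disabled, while val_dl serves the held-out paths; both use cfg.evaluation.batch_size.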
# parse the dataset into metadata that's easier to sample
# note: everything below is leftover from the upstream audio dataloader and is kept commented
# out; it references helpers (_get_quant_extension, _get_phone_extension, decode_to_file)
# that do not exist in this file.
"""
def create_dataset_metadata( skip_existing=True ):
    symmap = get_symmap()

    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

    def add( dir, type="training", audios=True, texts=True ):
        name = str(dir)
        name = name.replace(root, "")

        speaker_name = name

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        try:
            metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
        except Exception as e:
            metadata = {}

        if not os.path.isdir(f'{root}/{name}/'):
            return

        # tqdm.write(f'{root}/{name}')
        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file
        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

        wrote = False

        for id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')

                if audios and not quant_path.exists():
                    continue

                key = f'{type}/{speaker_name}/{id}'

                if skip_existing and id in metadata:
                    continue

                wrote = True

                if id not in metadata:
                    metadata[id] = {}

                utterance_metadata = {}
                if audios:
                    # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
                    dac = np.load(quant_path, allow_pickle=True)[()]
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    if "text" in dac["metadata"]:
                        utterance_metadata["text"] = dac["metadata"]["text"]
                    if "phonemes" in dac["metadata"]:
                        utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
                    if "language" in dac["metadata"]:
                        utterance_metadata["language"] = dac["metadata"]["language"]
                    if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
                        utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]

                for k, v in utterance_metadata.items():
                    metadata[id][k] = v

            except Exception as e:
                tqdm.write(f'Error while processing {id}: {e}')

        if wrote:
            with open(str(metadata_path), "w", encoding="utf-8") as f:
                f.write( json.dumps( metadata ) )

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
        add( data_dir, type="training" )

    # validation
    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
        add( data_dir, type="validation" )

    # noise
    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
        add( data_dir, type="noise", texts=False )

# parse yaml to create an hdf5 file
def create_dataset_hdf5( skip_existing=True ):
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)
    hf = cfg.hdf5

    symmap = get_symmap()

    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    def add( dir, type="training", audios=True, texts=True ):
        name = str(dir)
        name = name.replace(root, "")

        # yucky
        speaker_name = name
        if "LibriTTS-R" in speaker_name:
            speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())

        if not os.path.isdir(f'{root}/{name}/'):
            return

        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file
        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

        for id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
                text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True

                if not quant_exists:
                    continue

                key = f'{type}/{speaker_name}/{id}'

                if skip_existing and key in hf:
                    continue

                group = hf.create_group(key) if key not in hf else hf[key]

                if id not in metadata:
                    metadata[id] = {}

                utterance_metadata = {}

                # audio
                if audios:
                    dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    if "text" in dac["metadata"]:
                        utterance_metadata["text"] = dac["metadata"]["text"]
                    if "phonemes" in dac["metadata"]:
                        utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
                    if "language" in dac["metadata"]:
                        utterance_metadata["language"] = dac["metadata"]["language"]
                    if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
                        utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]

                    if "audio" not in group:
                        group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

                # text
                if texts:
                    if not utterance_metadata and text_exists:
                        utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())

                    phn = "".join(utterance_metadata["phonemes"])
                    phn = cfg.tokenizer.encode(phn)
                    phn = np.array(phn).astype(np.uint8)

                    if "text" not in group:
                        group.create_dataset('text', data=phn, compression='lzf')

                for k, v in utterance_metadata.items():
                    group.attrs[k] = v
                    metadata[id][k] = v

            except Exception as e:
                tqdm.write(f'Error while processing {id}: {e}')

        with open(str(metadata_path), "w", encoding="utf-8") as f:
            f.write( json.dumps( metadata ) )

    # training
    for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
        add( data_dir, type="training" )

    # validation
    for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
        add( data_dir, type="validation" )

    # noise
    for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
        add( data_dir, type="noise", texts=False )

    # write symmap
    if "symmap" in hf:
        del hf['symmap']

    hf.create_dataset('symmap', data=json.dumps(symmap))
    hf.close()

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Save trained model to path.")
    parser.add_argument("--action", type=str)
    parser.add_argument("--tasks", type=str)
    args, unknown = parser.parse_known_args()

    task = args.action

    cfg.dataset.workers = 1
args.action == "hdf5": create_dataset_hdf5() elif args.action == "list-dataset": dataset = [] for group in os.listdir(cfg.data_dir): for name in os.listdir(cfg.data_dir / group): if len(os.listdir(cfg.data_dir / group / name)) == 0: continue dataset.append(f'{group}/{name}') _logger.info(json.dumps(dataset)) elif args.action == "metadata": create_dataset_metadata() elif args.action == "sample": train_dl, subtrain_dl, val_dl = create_train_val_dataloader() samples = { "training": [ next(iter(train_dl)), next(iter(train_dl)) ], "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ], #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ], } Path("./data/sample-test/").mkdir(parents=True, exist_ok=True) for k, v in samples.items(): for i in range(len(v)): for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."): try: decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) except Exception as e: _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}") try: decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) except Exception as e: _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}") v[i]['proms'][j] = v[i]['proms'][j].shape v[i]['resps'][j] = v[i]['resps'][j].shape for k, v in samples.items(): for i in range(len(v)): _logger.info(f'{k}[{i}]: {v[i]}') elif args.action == "validate": train_dl, subtrain_dl, val_dl = create_train_val_dataloader() missing = set() for i in range(len( train_dl.dataset )): batch = train_dl.dataset[i] text = batch['text'] phonemes = batch['metadata']['phonemes'] decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ] for i, token in enumerate(decoded): if token != "": continue phone = phonemes[i] _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" ) missing |= set([phone]) _logger.info( f"Missing tokens: {missing}" ) elif args.action == "tasks": index = 0 cfg.dataset.tasks_list = args.tasks.split(",") train_dl, subtrain_dl, val_dl = create_train_val_dataloader() batch = next(iter(train_dl)) for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]): if task not in cfg.dataset.tasks_list: continue _logger.info( f'{text} {task} {cfg.model.resp_levels}') _logger.info( f'{proms.shape} {resps.shape}' ) tokens = 0 tokens += sum([ text.shape[0] for text in batch["text"] ]) tokens += sum([ resps.shape[0] for resps in batch["resps"] ]) _logger.info( f'{tokens}' ) decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" ) decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" ) break """