598 lines
17 KiB
Python
Executable File
598 lines
17 KiB
Python
Executable File
# todo: clean this mess up
|
|
|
|
import copy
|
|
import h5py
|
|
import json
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
import random
|
|
import torch
|
|
import itertools
|
|
|
|
from .config import cfg
|
|
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
|
from .utils.distributed import global_rank, local_rank, world_size
|
|
from .utils.io import torch_save, torch_load
|
|
|
|
from collections import defaultdict
|
|
from functools import cache, cached_property
|
|
from itertools import groupby, zip_longest
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from torch import Tensor
|
|
from torch.utils.data import DataLoader, Dataset as _Dataset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
from PIL import Image, ImageDraw
|
|
import torchvision.transforms as transforms
|
|
|
|
from tqdm.auto import tqdm
|
|
# torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
# Module-level logger, named after this module so handlers can filter on it.
_logger = logging.getLogger(name=__name__)
|
|
|
|
# to-do: clean up this symmap mess
|
|
def get_symmap():
    """Return the configured tokenizer's vocabulary.

    Thin accessor over ``cfg.tokenizer.get_vocab()``; the value is whatever the
    tokenizer returns (presumably a token -> id mapping — TODO confirm).
    """
    tokenizer = cfg.tokenizer
    vocab = tokenizer.get_vocab()
    return vocab
|
|
|
|
def tokenize( s ):
    """Encode *s* with the configured tokenizer.

    A list of string fragments is concatenated (no separator) before encoding;
    any other value is handed to ``cfg.tokenizer.encode`` unchanged.
    """
    text = "".join( s ) if isinstance( s, list ) else s
    return cfg.tokenizer.encode( text )
|
|
|
|
"""
|
|
def _replace_file_extension(path, suffix):
|
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
|
|
|
def _get_hdf5_path(path):
|
|
# to-do: better validation
|
|
return str(path)
|
|
|
|
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
|
data_dir = str(data_dir)
|
|
|
|
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
|
|
|
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
|
|
|
|
def _get_paths_of_extensions( path, validate=False ):
|
|
if isinstance(path, str):
|
|
path = Path(path)
|
|
|
|
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
|
|
|
|
def _interleaved_reorder(l, fn):
|
|
groups = defaultdict(list)
|
|
for e in l:
|
|
groups[fn(e)].append(e)
|
|
groups = {k: groups[k] for k in sorted(groups)}
|
|
for interleaved in zip_longest(*groups.values()):
|
|
for value in interleaved:
|
|
if value is not None:
|
|
yield value
|
|
"""
|
|
|
|
class Dataset(_Dataset):
    """Image dataset whose label text is derived from each file's name.

    ``__getitem__`` returns the image (resized to ``(height, width)``, converted
    to a tensor, normalized and cast to ``cfg.trainer.dtype``) together with the
    tokenized, upper-cased file stem.
    """

    def __init__(
        self,
        paths,
        width=300,
        height=80,
        stacks=0,
        symmap=None,
        training=False,
    ):
        """
        Args:
            paths: sequence of image paths; ``.stem`` is read in ``__getitem__``,
                so ``pathlib.Path``-like objects are expected.
            width: target width passed to ``transforms.Resize``.
            height: target height passed to ``transforms.Resize``.
            stacks: stored as ``self.stacks``; not otherwise used in this class.
            symmap: symbol map stored on the dataset. ``None`` (default) resolves
                to ``get_symmap()`` at construction time.
            training: selects the training vs. validation configuration; in
                distributed runs, only training datasets are sharded per rank.

        Raises:
            ValueError: if no paths remain (after per-rank sharding).
        """
        super().__init__()
        self._head = None
        self.sampler = None

        self.width = width
        self.height = height
        self.stacks = stacks

        self.paths = paths
        self.image_dtype = cfg.trainer.dtype
        # Resolved here rather than as a default argument so importing the module
        # does not require the tokenizer, and instances don't share one default dict.
        self.symmap = get_symmap() if symmap is None else symmap

        self.training = training
        self.dataset_type = "training" if self.training else "validation"
        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation

        self.transform = transforms.Compose([
            # NOTE(review): an earlier comment reported validation breaking when this
            # Resize is set; a non-int `width` (e.g. a dict passed positionally as the
            # 2nd argument) would produce exactly that — check callers.
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # to-do: do not do validation if there's nothing in the validation
        # this just makes it be happy
        if len(self.dataset) == 0:
            self.dataset = cfg.dataset.training

        # Split the dataset per GPU: each rank keeps a disjoint stride of the paths.
        # (Comparing against 0 instead of the rank made every rank train on the same slice.)
        if cfg.distributed and self.training:
            self.paths = self._shard_paths(self.paths, global_rank(), world_size())

        if len(self.paths) == 0:
            raise ValueError(f"No valid path is found for {self.dataset_type}")

    @staticmethod
    def _shard_paths(paths, rank, world):
        """Return the elements of *paths* whose index is congruent to *rank* modulo *world*."""
        return [ path for i, path in enumerate(paths) if i % world == rank ]

    @cached_property
    def sampler_state_dict_path(self):
        """Per-rank file path for sampler state, under ``cfg.rel_path``."""
        return cfg.rel_path / f"sampler.rank{global_rank()}.pt"

    def save_state_dict(self, path = None):
        """Save sampler state — currently disabled; always returns None.

        The previous implementation is kept below as comments for reference.
        """
        # if path is None:
        #     path = self.sampler_state_dict_path
        # if self.sampler is not None:
        #     state_dict = self.sampler.get_state()
        # elif self.samplers is not None:
        #     state_dict = {
        #         "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
        #     }
        # torch_save(state_dict, path)
        return

    def load_state_dict(self, path = None):
        """Load sampler state — currently disabled; always returns None.

        The previous implementation is kept below as comments for reference.
        """
        # if path is None:
        #     path = self.sampler_state_dict_path
        # if not path.exists():
        #     return
        # state_dict = torch_load(path)
        # if self.sampler is not None:
        #     state_dict = self.sampler.set_state(state_dict)
        # else:
        #     for name, sampler in state_dict["samplers"].items():
        #         if name not in self.samplers:
        #             continue
        #         self.samplers[name].set_state( sampler )
        return

    def __getitem__(self, index):
        """Return ``dict(index, path, image, text)`` for the sample at *index*."""
        path = self.paths[index]

        # The label text is the file name without extension, upper-cased.
        tokens = tokenize( path.stem.upper() )
        # NOTE(review): uint8 cannot hold token ids >= 256; assumes the vocab fits in a byte — confirm.
        text = torch.tensor( tokens ).to(dtype=torch.uint8)

        # Context manager releases the file handle once the RGB copy is made.
        with Image.open(path) as img:
            image = img.convert('RGB') # resnet has to be RGB

        image = self.transform(image).to(dtype=self.image_dtype)

        return dict(
            index=index,
            path=path,
            image=image,
            text=text,
        )

    def head_(self, n):
        """Limit ``len(self)`` to at most *n* samples (a falsy *n* removes the limit)."""
        self._head = n

    def training_(self, value):
        """Set the ``training`` flag (does not recompute ``dataset``/``dataset_type``)."""
        self.training = value

    def __len__(self):
        # A falsy head (None or 0) means "no limit".
        return min(len(self.paths), self._head or len(self.paths))
|
|
|
|
|
|
def collate_fn(samples: list[dict]):
    """Transpose a list of per-sample dicts into a dict of per-key lists.

    Keys are taken from the first sample (in its order) and every sample must
    provide them. Values are collected as-is — no stacking or padding.
    """
    batch: dict[str, Any] = {}
    for key in samples[0]:
        batch[key] = [sample[key] for sample in samples]
    return batch
|
|
|
|
|
|
def _seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's seed.

    The seed is ``torch.initial_seed()`` folded into the 32-bit range accepted by
    ``np.random.seed``. *worker_id* is unused but required by the callback signature.
    """
    seed = torch.initial_seed() % (1 << 32)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
|
def _create_dataloader(dataset, training):
    """Build a DataLoader for *dataset*.

    Args:
        dataset: dataset whose ``sampler`` attribute (possibly None) selects the
            sampling mode.
        training: if True, use ``cfg.hyperparameters.batch_size``, drop the last
            partial batch and pass ``dataset.sampler``; otherwise use
            ``cfg.evaluation.batch_size`` and no sampler.

    A ``BatchedOrderedSampler`` is passed as ``batch_sampler``; DataLoader then
    requires batch_size/shuffle/sampler/drop_last to be left unset.
    """
    sampler = dataset.sampler

    if isinstance(sampler, BatchedOrderedSampler):
        kwargs = dict(
            batch_sampler=sampler,
        )
    else:
        sampler = sampler if training else None
        kwargs = dict(
            # DataLoader raises ValueError for shuffle=True combined with an
            # explicit sampler, so only shuffle when no sampler is in use.
            shuffle=sampler is None,
            batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
            drop_last=training,
            sampler=sampler,
        )

    return DataLoader(
        dataset=dataset,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        **kwargs,
    )
|
|
|
|
def _glob_images(dirs):
    """Return all ``*.jpg`` then ``*.png`` paths found recursively under each directory in *dirs*."""
    paths = []
    for data_dir in dirs:
        paths.extend(data_dir.rglob("*.jpg"))
        paths.extend(data_dir.rglob("*.png"))
    return paths


def _load_train_val_paths( val_ratio=0.1 ):
    """Collect image paths and split them into training and validation lists.

    Training candidates are every ``*.jpg``/``*.png`` found recursively under
    ``cfg.dataset.training``. If ``cfg.dataset.validation`` is empty, a
    reproducible random ``val_ratio`` fraction (rounded down) of them is held out
    as validation, disjoint from training; otherwise validation paths are globbed
    from the configured validation directories.

    Args:
        val_ratio: fraction held out for validation when no validation
            directories are configured; assumed to be in [0, 1].

    Returns:
        ``(train_paths, val_paths)``, both sorted. If validation would be empty it
        falls back to ``[train_paths[0]]`` so downstream code has a sample.

    Raises:
        RuntimeError: if no training images were found.
    """
    train_paths = _glob_images(cfg.dataset.training)

    if len(cfg.dataset.validation) == 0:
        # Private RNG with a fixed seed: same permutation as seeding the global
        # `random` module with 0, without clobbering global RNG state.
        random.Random(0).shuffle(train_paths)

        # Disjoint split (the old `[:-val_len]` slice overlapped training and was
        # empty when val_len == 0). int() == floor for non-negative values.
        val_len = int(len(train_paths) * val_ratio)
        train_len = len(train_paths) - val_len

        val_paths = train_paths[train_len:]
        train_paths = train_paths[:train_len]
    else:
        val_paths = _glob_images(cfg.dataset.validation)

    train_paths, val_paths = sorted(train_paths), sorted(val_paths)

    if len(train_paths) == 0:
        raise RuntimeError(f"Failed to find any .jpg or .png file in {cfg.dataset.training}.")

    # Downstream code expects at least one validation sample.
    if len(val_paths) == 0:
        val_paths = [ train_paths[0] ]

    return train_paths, val_paths
|
|
|
|
def create_datasets():
    """Build the training and validation ``Dataset`` objects.

    The validation dataset reuses the training dataset's symbol map and is
    limited to ``cfg.evaluation.size`` samples via ``head_``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(
        train_paths,
        training=True,
    )

    # symmap must be passed by keyword: Dataset's second positional parameter is
    # `width`, so passing it positionally fed a dict into the Resize transform.
    val_dataset = Dataset(
        val_paths,
        symmap=train_dataset.symmap,
    )

    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset
|
|
|
|
def create_train_val_dataloader():
    """Create the training, sub-training and validation DataLoaders.

    The "subtrain" loader iterates a deep copy of the training dataset limited
    to ``cfg.evaluation.size`` samples with its training flag cleared, so
    evaluation can also be run on training data.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)`` — note the order.
    """
    train_dataset, val_dataset = create_datasets()

    # Evaluation-sized, non-training clone of the training set (deepcopy is slow).
    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.head_(cfg.evaluation.size)
    subtrain_dataset.training_(False)

    train_dl, val_dl, subtrain_dl = (
        _create_dataloader(train_dataset, training=True),
        _create_dataloader(val_dataset, training=False),
        _create_dataloader(subtrain_dataset, training=False),
    )

    _logger.info(str(train_dataset.symmap))
    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

    # Narrows the type for static checkers; stripped under `python -O`.
    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
|
|
|
|
# parse dataset into better to sample metadata
|
|
"""
|
|
def create_dataset_metadata( skip_existing=True ):
|
|
symmap = get_symmap()
|
|
|
|
root = str(cfg.data_dir)
|
|
metadata_root = str(cfg.metadata_dir)
|
|
|
|
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
def add( dir, type="training", audios=True, texts=True ):
|
|
name = str(dir)
|
|
name = name.replace(root, "")
|
|
|
|
speaker_name = name
|
|
|
|
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
|
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
try:
|
|
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
|
except Exception as e:
|
|
metadata = {}
|
|
|
|
if not os.path.isdir(f'{root}/{name}/'):
|
|
return
|
|
|
|
# tqdm.write(f'{root}/{name}')
|
|
files = os.listdir(f'{root}/{name}/')
|
|
|
|
# grab IDs for every file
|
|
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
|
|
|
wrote = False
|
|
|
|
for id in tqdm(ids, desc=f"Processing {name}"):
|
|
try:
|
|
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
|
|
|
|
if audios and not quant_path.exists():
|
|
continue
|
|
|
|
key = f'{type}/{speaker_name}/{id}'
|
|
|
|
if skip_existing and id in metadata:
|
|
continue
|
|
|
|
wrote = True
|
|
|
|
if id not in metadata:
|
|
metadata[id] = {}
|
|
|
|
utterance_metadata = {}
|
|
if audios:
|
|
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
|
|
dac = np.load(quant_path, allow_pickle=True)[()]
|
|
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
|
|
|
if "text" in dac["metadata"]:
|
|
utterance_metadata["text"] = dac["metadata"]["text"]
|
|
if "phonemes" in dac["metadata"]:
|
|
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
|
|
if "language" in dac["metadata"]:
|
|
utterance_metadata["language"] = dac["metadata"]["language"]
|
|
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
|
|
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
|
|
|
for k, v in utterance_metadata.items():
|
|
metadata[id][k] = v
|
|
|
|
except Exception as e:
|
|
tqdm.write(f'Error while processing {id}: {e}')
|
|
|
|
if wrote:
|
|
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
|
f.write( json.dumps( metadata ) )
|
|
|
|
# training
|
|
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
|
add( data_dir, type="training" )
|
|
|
|
# validation
|
|
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
|
|
add( data_dir, type="validation" )
|
|
|
|
# noise
|
|
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
|
|
add( data_dir, type="noise", texts=False )
|
|
|
|
# parse yaml to create an hdf5 file
|
|
def create_dataset_hdf5( skip_existing=True ):
|
|
cfg.dataset.use_hdf5 = True
|
|
cfg.load_hdf5(write=True)
|
|
hf = cfg.hdf5
|
|
|
|
symmap = get_symmap()
|
|
|
|
root = str(cfg.data_dir)
|
|
metadata_root = str(cfg.metadata_dir)
|
|
|
|
|
|
def add( dir, type="training", audios=True, texts=True ):
|
|
name = str(dir)
|
|
name = name.replace(root, "")
|
|
|
|
# yucky
|
|
speaker_name = name
|
|
if "LibriTTS-R" in speaker_name:
|
|
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
|
|
|
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
|
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
|
|
|
if not os.path.isdir(f'{root}/{name}/'):
|
|
return
|
|
|
|
files = os.listdir(f'{root}/{name}/')
|
|
|
|
# grab IDs for every file
|
|
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
|
|
|
for id in tqdm(ids, desc=f"Processing {name}"):
|
|
try:
|
|
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
|
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
|
|
|
if not quant_exists:
|
|
continue
|
|
|
|
key = f'{type}/{speaker_name}/{id}'
|
|
|
|
if skip_existing and key in hf:
|
|
continue
|
|
|
|
group = hf.create_group(key) if key not in hf else hf[key]
|
|
|
|
if id not in metadata:
|
|
metadata[id] = {}
|
|
|
|
utterance_metadata = {}
|
|
|
|
# audio
|
|
if audios:
|
|
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
|
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
|
|
|
if "text" in dac["metadata"]:
|
|
utterance_metadata["text"] = dac["metadata"]["text"]
|
|
if "phonemes" in dac["metadata"]:
|
|
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
|
|
if "language" in dac["metadata"]:
|
|
utterance_metadata["language"] = dac["metadata"]["language"]
|
|
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
|
|
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
|
|
|
if "audio" not in group:
|
|
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
|
|
|
# text
|
|
if texts:
|
|
if not utterance_metadata and text_exists:
|
|
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
|
|
|
phn = "".join(utterance_metadata["phonemes"])
|
|
phn = cfg.tokenizer.encode(phn)
|
|
phn = np.array(phn).astype(np.uint8)
|
|
|
|
if "text" not in group:
|
|
group.create_dataset('text', data=phn, compression='lzf')
|
|
|
|
for k, v in utterance_metadata.items():
|
|
group.attrs[k] = v
|
|
metadata[id][k] = v
|
|
|
|
except Exception as e:
|
|
tqdm.write(f'Error while processing {id}: {e}')
|
|
|
|
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
|
f.write( json.dumps( metadata ) )
|
|
|
|
# training
|
|
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
|
add( data_dir, type="training" )
|
|
|
|
# validation
|
|
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
|
add( data_dir, type="validation" )
|
|
|
|
# noise
|
|
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
|
|
add( data_dir, type="noise", texts=False )
|
|
|
|
# write symmap
|
|
if "symmap" in hf:
|
|
del hf['symmap']
|
|
|
|
hf.create_dataset('symmap', data=json.dumps(symmap))
|
|
hf.close()
|
|
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser("Save trained model to path.")
|
|
parser.add_argument("--action", type=str)
|
|
parser.add_argument("--tasks", type=str)
|
|
args, unknown = parser.parse_known_args()
|
|
|
|
task = args.action
|
|
|
|
cfg.dataset.workers = 1
|
|
|
|
if args.action == "hdf5":
|
|
create_dataset_hdf5()
|
|
elif args.action == "list-dataset":
|
|
dataset = []
|
|
for group in os.listdir(cfg.data_dir):
|
|
for name in os.listdir(cfg.data_dir / group):
|
|
if len(os.listdir(cfg.data_dir / group / name)) == 0:
|
|
continue
|
|
dataset.append(f'{group}/{name}')
|
|
|
|
_logger.info(json.dumps(dataset))
|
|
elif args.action == "metadata":
|
|
create_dataset_metadata()
|
|
elif args.action == "sample":
|
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
|
|
|
samples = {
|
|
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
|
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
|
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
|
}
|
|
|
|
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
|
|
|
|
for k, v in samples.items():
|
|
for i in range(len(v)):
|
|
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
|
|
try:
|
|
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
|
|
except Exception as e:
|
|
_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
|
|
try:
|
|
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
|
|
except Exception as e:
|
|
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
|
|
v[i]['proms'][j] = v[i]['proms'][j].shape
|
|
v[i]['resps'][j] = v[i]['resps'][j].shape
|
|
|
|
for k, v in samples.items():
|
|
for i in range(len(v)):
|
|
_logger.info(f'{k}[{i}]: {v[i]}')
|
|
elif args.action == "validate":
|
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
|
|
|
missing = set()
|
|
|
|
for i in range(len( train_dl.dataset )):
|
|
batch = train_dl.dataset[i]
|
|
|
|
text = batch['text']
|
|
phonemes = batch['metadata']['phonemes']
|
|
|
|
decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
|
|
for i, token in enumerate(decoded):
|
|
if token != "<unk>":
|
|
continue
|
|
|
|
phone = phonemes[i]
|
|
|
|
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
|
|
|
|
missing |= set([phone])
|
|
|
|
_logger.info( f"Missing tokens: {missing}" )
|
|
|
|
|
|
elif args.action == "tasks":
|
|
index = 0
|
|
cfg.dataset.tasks_list = args.tasks.split(",")
|
|
|
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
|
batch = next(iter(train_dl))
|
|
|
|
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
|
|
if task not in cfg.dataset.tasks_list:
|
|
continue
|
|
|
|
_logger.info( f'{text} {task} {cfg.model.resp_levels}')
|
|
_logger.info( f'{proms.shape} {resps.shape}' )
|
|
|
|
tokens = 0
|
|
tokens += sum([ text.shape[0] for text in batch["text"] ])
|
|
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
|
|
_logger.info( f'{tokens}' )
|
|
|
|
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
|
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
|
break
|
|
""" |