From 769f67dcfeffc4b8a0c425e15bd399e0f6f5e154 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 21 Sep 2024 12:19:34 -0500
Subject: [PATCH] actually fix validation of phonemes in the symmap

---
 vall_e/data.py     | 76 +++++++++++++++++++++++-----------------
 vall_e/utils/io.py |  2 ++
 2 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 6249cb6..fb2b0f2 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -16,6 +16,7 @@ from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
+from .utils import setup_logging
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -594,7 +595,7 @@ class Dataset(_Dataset):
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
 		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 		self.sampler_order = cfg.dataset.sample_order
-		self.sampler_shuffle = cfg.dataset.sample_shuffle
+		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
@@ -878,21 +879,21 @@ class Dataset(_Dataset):
 		return path, text, resps
 
 	def get_similar_utterance(self, spkr_name, reference, offset=0 ):
-		# lots of boilerplate checks
-		metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
-		if not metadata_path.exists():
-			return None
-		metadata = json_read( metadata_path )
+		metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} )
+
 		if reference not in metadata:
 			return None
+
 		reference_metadata = metadata[reference]
+
 		if "similar" not in reference_metadata:
 			return None
+
 		if len(reference_metadata["similar"]) >= offset:
-			offset = -1
+			offset = 0
+
 		metadata_keys = list(metadata.keys())
-		index = reference_metadata["similar"][offset]
-		name = metadata_keys[index]
+		name = metadata_keys[reference_metadata["similar"][offset]]
 		return name
 
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
@@ -1259,7 +1260,7 @@ def _seed_worker(worker_id):
 
 def _create_dataloader(dataset, training):
 	kwargs = dict(
-		shuffle=False,
+		shuffle=not training,
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
 		drop_last=training,
 		sampler=dataset.sampler if training else None,
@@ -1470,7 +1471,7 @@ def create_dataset_hdf5( skip_existing=True ):
 	metadata_root = str(cfg.metadata_dir)
 
-	def add( dir, type="training", audios=True, texts=True ):
+	def add( dir, type="training", audios=True, texts=True, verbose=False ):
 		name = str(dir)
 		name = name.replace(root, "")
 		speaker_name = remap_speaker_name( name )
@@ -1512,7 +1513,7 @@ def create_dataset_hdf5( skip_existing=True ):
 			group.create_dataset('text', data=phn, compression='lzf')
 		"""
 
-		for id in tqdm(ids, desc=f"Processing {name}"):
+		for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
 			try:
 				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
 				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
@@ -1594,6 +1595,7 @@ if __name__ == "__main__":
 
 	task = args.action
 
+	setup_logging()
 	cfg.dataset.workers = 1
 
 	if args.action == "hdf5":
@@ -1641,29 +1643,12 @@ if __name__ == "__main__":
 				_logger.info(f'{k}[{i}]: {v[i]}')
 	elif args.action == "validate":
 		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-
-		missing = set()
-
 		dataset = train_dl.dataset
+		missing = []
+		symmap = get_phone_symmap()
+
 		for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
-			"""
-			batch = train_dl.dataset[i]
-			text = batch['text']
-			phonemes = batch['metadata']['phonemes']
-
-			decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
-			for i, token in enumerate(decoded):
-				if token != "<unk>":
-					continue
-
-				phone = phonemes[i]
-
-				_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
-
-				missing |= set([phone])
-			"""
-
 			if dataset.sampler_type == "group":
 				spkr_group = dataset.spkr_groups[index]
 				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
@@ -1688,21 +1673,36 @@ if __name__ == "__main__":
 				if key not in cfg.hdf5:
 					continue
 				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
-				text = cfg.hdf5[key]["text"][:]
 			else:
 				_, metadata = _load_quants(path, return_metadata=True)
-				text = tokenize( phonemes )
+			phonemes = metadata["phonemes"]
+			for i, phone in enumerate( phonemes ):
+				if phone in symmap:
+					continue
+				if phone in missing:
+					continue
+
+				_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
+				missing.append( phone )
+
+			"""
+			text = tokenize( phonemes )[1:-1]
+			unk_token = tokenize("<unk>")[1]
+
+			if unk_token in text:
+				print( unk_token, text, phonemes )
+
 			for i, token in enumerate(text):
-				if token != "<unk>":
+				if token != unk_token:
 					continue
 
 				phone = phonemes[i]
-
-				_logger.info( f"{path}: {phonemes}: {phone}" )
-
+				if phone not in missing:
+					_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
 				missing |= set([phone])
+			"""
 
 	_logger.info( f"Missing tokens: {missing}" )
diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py
index b2d17c5..2a4e3c7 100644
--- a/vall_e/utils/io.py
+++ b/vall_e/utils/io.py
@@ -80,6 +80,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
 		# not a dict of tensors, put it as metadata
 		try:
 			metadata[k] = json.dumps(v)
+			if isinstance( metadata[k], bytes ):
+				metadata[k] = metadata[k].decode('utf-8')
 		except Exception as e:
 			pass
 
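
For reference, a minimal standalone sketch of the validation pattern the patch lands on: walk each utterance's phoneme list and collect anything absent from the phone symmap. The find_missing_phonemes helper, the entries iterable, and the toy symmap below are illustrative assumptions, not code from the repository.

import logging

_logger = logging.getLogger(__name__)

def find_missing_phonemes( entries, symmap ):
	# entries: iterable of (path, metadata) pairs, where metadata["phonemes"] is a list of str
	# symmap: dict mapping phoneme string -> token id, like get_phone_symmap() in the patch
	missing = []
	for path, metadata in entries:
		phonemes = metadata.get("phonemes", [])
		for i, phone in enumerate( phonemes ):
			if phone in symmap or phone in missing:
				continue
			# report the offending utterance and phoneme, mirroring the patch's log format
			_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
			missing.append( phone )
	return missing

# example: the glottal stop is absent from the toy symmap, so it is reported once
# find_missing_phonemes( [("a/1.wav", {"phonemes": ["a", "ʔ", "ʔ"]})], {"a": 4} ) -> ["ʔ"]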