actually fix validation of phonemes in the symmap

mrq 2024-09-21 12:19:34 -05:00
parent c8d4716a9f
commit 769f67dcfe
2 changed files with 40 additions and 38 deletions


@@ -16,6 +16,7 @@ from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
+from .utils import setup_logging

 from collections import defaultdict
 from functools import cache, cached_property
@@ -594,7 +595,7 @@ class Dataset(_Dataset):
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
 		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 		self.sampler_order = cfg.dataset.sample_order
-		self.sampler_shuffle = cfg.dataset.sample_shuffle
+		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
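Note: with this change only the training split honors cfg.dataset.sample_shuffle; validation now always shuffles. A toy illustration of the new default, assuming the flag is a plain bool:

    # validation ignores the config flag so repeated eval runs don't
    # keep reading the same fixed head of the dataset
    sample_shuffle = False
    dataset_type = "validation"
    sampler_shuffle = sample_shuffle if dataset_type == "training" else True
    assert sampler_shuffle is True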
@@ -878,21 +879,21 @@ class Dataset(_Dataset):
 		return path, text, resps

 	def get_similar_utterance(self, spkr_name, reference, offset=0 ):
 		# lots of boilerplate checks
-		metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
-		if not metadata_path.exists():
-			return None
-		metadata = json_read( metadata_path )
+		metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} )
 		if reference not in metadata:
 			return None
 		reference_metadata = metadata[reference]
 		if "similar" not in reference_metadata:
 			return None
 		if len(reference_metadata["similar"]) >= offset:
-			offset = -1
+			offset = 0
 		metadata_keys = list(metadata.keys())
-		index = reference_metadata["similar"][offset]
-		name = metadata_keys[index]
+		name = metadata_keys[reference_metadata["similar"][offset]]
 		return name

 	def sample_prompts(self, spkr_name, reference, should_trim=True):
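The existence-check boilerplate collapses into json_read itself, which (judging from the call site) now accepts a default to return when the file is missing. A standalone sketch of that assumed behavior, not the repo's actual utils.io implementation:

    import json
    from pathlib import Path

    def json_read(path, default=None):
        # hypothetical stand-in for utils.io.json_read: return `default`
        # instead of raising when the file does not exist
        path = Path(path)
        if not path.exists():
            return default
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

The out-of-range fallback also moves from offset -1 (the last entry of the "similar" list) to 0 (the first, presumably closest match if the list is sorted best-first).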
@@ -1259,7 +1260,7 @@ def _seed_worker(worker_id):
 def _create_dataloader(dataset, training):
 	kwargs = dict(
-		shuffle=False,
+		shuffle=not training,
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
 		drop_last=training,
 		sampler=dataset.sampler if training else None,
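shuffle and sampler are mutually exclusive in PyTorch's DataLoader, so shuffle=not training is safe here: the training path hands ordering to the custom sampler, while the evaluation path (sampler=None) now gets shuffling, matching the Dataset-level change above. A condensed sketch of the arrangement:

    from torch.utils.data import DataLoader

    def _create_dataloader(dataset, training):
        # training: the custom sampler owns ordering, so shuffle must stay False
        # evaluation: no sampler, so the DataLoader itself shuffles
        return DataLoader(
            dataset,
            shuffle=not training,
            sampler=dataset.sampler if training else None,
            drop_last=training,
        )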
@@ -1470,7 +1471,7 @@ def create_dataset_hdf5( skip_existing=True ):
 	metadata_root = str(cfg.metadata_dir)

-	def add( dir, type="training", audios=True, texts=True ):
+	def add( dir, type="training", audios=True, texts=True, verbose=False ):
 		name = str(dir)
 		name = name.replace(root, "")
 		speaker_name = remap_speaker_name( name )
@@ -1512,7 +1513,7 @@ def create_dataset_hdf5( skip_existing=True ):
 				group.create_dataset('text', data=phn, compression='lzf')
 		"""
-		for id in tqdm(ids, desc=f"Processing {name}"):
+		for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
 			try:
 				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
 				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
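add() grows a verbose flag that defaults to False, and the per-speaker progress bar is gated on it through tqdm's disable parameter, so bulk HDF5 creation no longer prints one bar per speaker. A minimal sketch of the pattern with a hypothetical item list:

    from tqdm import tqdm

    def add(ids, name, verbose=False):
        # disable=True turns tqdm into a silent passthrough over the iterable
        for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
            pass  # per-utterance processing elided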
@@ -1594,6 +1595,7 @@ if __name__ == "__main__":
 	task = args.action

+	setup_logging()
 	cfg.dataset.workers = 1

 	if args.action == "hdf5":
@@ -1641,29 +1643,12 @@ if __name__ == "__main__":
 			_logger.info(f'{k}[{i}]: {v[i]}')
 	elif args.action == "validate":
 		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-		missing = set()
+		dataset = train_dl.dataset
+		missing = []
+		symmap = get_phone_symmap()
 		for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
-			"""
-			batch = train_dl.dataset[i]
-			text = batch['text']
-			phonemes = batch['metadata']['phonemes']
-			decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
-			for i, token in enumerate(decoded):
-				if token != "<unk>":
-					continue
-				phone = phonemes[i]
-				_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
-				missing |= set([phone])
-			"""
 			if dataset.sampler_type == "group":
 				spkr_group = dataset.spkr_groups[index]
 				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
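missing becomes a list rather than a set, which preserves the order in which unknown phonemes are first encountered for the final report; the explicit membership guard in the next hunk keeps it duplicate-free. A toy run of that pattern:

    # list + membership guard = ordered, duplicate-free collection
    missing = []
    for phone in ["ʔ", "a", "ʔ"]:
        if phone not in missing:
            missing.append(phone)
    assert missing == ["ʔ", "a"]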
@@ -1688,21 +1673,36 @@ if __name__ == "__main__":
 				if key not in cfg.hdf5:
 					continue
 				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+				text = cfg.hdf5[key]["text"][:]
 			else:
 				_, metadata = _load_quants(path, return_metadata=True)
+				text = tokenize( phonemes )
 			phonemes = metadata["phonemes"]
+			for i, phone in enumerate( phonemes ):
+				if phone in symmap:
+					continue
+				if phone in missing:
+					continue
+				_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
+				missing.append( phone )
+			"""
 			text = tokenize( phonemes )[1:-1]
 			unk_token = tokenize("<unk>")[1]
 			if unk_token in text:
 				print( unk_token, text, phonemes )
 			for i, token in enumerate(text):
-				if token != "<unk>":
+				if token != unk_token:
 					continue
 				phone = phonemes[i]
-				_logger.info( f"{path}: {phonemes}: {phone}" )
+				if phone not in missing:
+					_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
 				missing |= set([phone])
+			"""
 		_logger.info( f"Missing tokens: {missing}" )


@@ -80,6 +80,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
 		# not a dict of tensors, put it as metadata
 		try:
 			metadata[k] = json.dumps(v)
+			if isinstance( metadata[k], bytes ):
+				metadata[k] = metadata[k].decode('utf-8')
 		except Exception as e:
 			pass
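The stdlib json.dumps always returns str, so the added isinstance guard only matters if the json module in play is an orjson-style drop-in whose dumps() returns bytes; either way the metadata value ends up as text. A sketch of the normalization in isolation:

    import json  # stdlib here; the repo's module may be a bytes-returning drop-in

    def normalize_metadata_value(v):
        # hypothetical helper mirroring the guard above: serialize, then
        # force str, since stdlib json.dumps returns str but orjson-style
        # dumps() returns bytes
        s = json.dumps(v)
        if isinstance(s, bytes):
            s = s.decode("utf-8")
        return s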