actually fix validation of phonemes in the symmap
This commit is contained in:
parent
c8d4716a9f
commit
769f67dcfe
|
@ -16,6 +16,7 @@ from .emb.g2p import encode as encode_phns
|
|||
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
||||
from .utils.distributed import global_rank, local_rank, world_size
|
||||
from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
|
||||
from .utils import setup_logging
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@ -594,7 +595,7 @@ class Dataset(_Dataset):
|
|||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
|
||||
self.sampler_order = cfg.dataset.sample_order
|
||||
self.sampler_shuffle = cfg.dataset.sample_shuffle
|
||||
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
|
||||
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
|
@ -878,21 +879,21 @@ class Dataset(_Dataset):
|
|||
return path, text, resps
|
||||
|
||||
def get_similar_utterance(self, spkr_name, reference, offset=0 ):
|
||||
# lots of boilerplate checks
|
||||
metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
metadata = json_read( metadata_path )
|
||||
metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} )
|
||||
|
||||
if reference not in metadata:
|
||||
return None
|
||||
|
||||
reference_metadata = metadata[reference]
|
||||
|
||||
if "similar" not in reference_metadata:
|
||||
return None
|
||||
|
||||
if len(reference_metadata["similar"]) >= offset:
|
||||
offset = -1
|
||||
offset = 0
|
||||
|
||||
metadata_keys = list(metadata.keys())
|
||||
index = reference_metadata["similar"][offset]
|
||||
name = metadata_keys[index]
|
||||
name = metadata_keys[reference_metadata["similar"][offset]]
|
||||
return name
|
||||
|
||||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||
|
@ -1259,7 +1260,7 @@ def _seed_worker(worker_id):
|
|||
|
||||
def _create_dataloader(dataset, training):
|
||||
kwargs = dict(
|
||||
shuffle=False,
|
||||
shuffle=not training,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
drop_last=training,
|
||||
sampler=dataset.sampler if training else None,
|
||||
|
@ -1470,7 +1471,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
metadata_root = str(cfg.metadata_dir)
|
||||
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
def add( dir, type="training", audios=True, texts=True, verbose=False ):
|
||||
name = str(dir)
|
||||
name = name.replace(root, "")
|
||||
speaker_name = remap_speaker_name( name )
|
||||
|
@ -1512,7 +1513,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
group.create_dataset('text', data=phn, compression='lzf')
|
||||
"""
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
|
||||
try:
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
||||
|
@ -1594,6 +1595,7 @@ if __name__ == "__main__":
|
|||
|
||||
task = args.action
|
||||
|
||||
setup_logging()
|
||||
cfg.dataset.workers = 1
|
||||
|
||||
if args.action == "hdf5":
|
||||
|
@ -1641,29 +1643,12 @@ if __name__ == "__main__":
|
|||
_logger.info(f'{k}[{i}]: {v[i]}')
|
||||
elif args.action == "validate":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
missing = set()
|
||||
|
||||
dataset = train_dl.dataset
|
||||
|
||||
missing = []
|
||||
symmap = get_phone_symmap()
|
||||
|
||||
for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
|
||||
"""
|
||||
batch = train_dl.dataset[i]
|
||||
text = batch['text']
|
||||
phonemes = batch['metadata']['phonemes']
|
||||
|
||||
decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
|
||||
for i, token in enumerate(decoded):
|
||||
if token != "<unk>":
|
||||
continue
|
||||
|
||||
phone = phonemes[i]
|
||||
|
||||
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
|
||||
|
||||
missing |= set([phone])
|
||||
"""
|
||||
|
||||
if dataset.sampler_type == "group":
|
||||
spkr_group = dataset.spkr_groups[index]
|
||||
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
|
||||
|
@ -1688,21 +1673,36 @@ if __name__ == "__main__":
|
|||
if key not in cfg.hdf5:
|
||||
continue
|
||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||
text = cfg.hdf5[key]["text"][:]
|
||||
else:
|
||||
_, metadata = _load_quants(path, return_metadata=True)
|
||||
text = tokenize( phonemes )
|
||||
|
||||
phonemes = metadata["phonemes"]
|
||||
|
||||
for i, phone in enumerate( phonemes ):
|
||||
if phone in symmap:
|
||||
continue
|
||||
if phone in missing:
|
||||
continue
|
||||
|
||||
_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
|
||||
missing.append( phone )
|
||||
|
||||
"""
|
||||
text = tokenize( phonemes )[1:-1]
|
||||
unk_token = tokenize("<unk>")[1]
|
||||
|
||||
if unk_token in text:
|
||||
print( unk_token, text, phonemes )
|
||||
|
||||
for i, token in enumerate(text):
|
||||
if token != "<unk>":
|
||||
if token != unk_token:
|
||||
continue
|
||||
|
||||
phone = phonemes[i]
|
||||
|
||||
_logger.info( f"{path}: {phonemes}: {phone}" )
|
||||
|
||||
if phone not in missing:
|
||||
_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
|
||||
missing |= set([phone])
|
||||
"""
|
||||
|
||||
_logger.info( f"Missing tokens: {missing}" )
|
||||
|
||||
|
|
|
@ -80,6 +80,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
|
|||
# not a dict of tensors, put it as metadata
|
||||
try:
|
||||
metadata[k] = json.dumps(v)
|
||||
if isinstance( metadata[k], bytes ):
|
||||
metadata[k] = metadata[k].decode('utf-8')
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user