actually fix validation of phonemes in the symmap
commit 769f67dcfe
parent c8d4716a9f
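In short: the `validate` action used to tokenize each utterance's phonemes and compare the resulting token ids against the literal string "<unk>". An id never equals a string, so the check always fell through and no missing phoneme was ever reported. The commit drops the tokenizer round-trip and tests each phoneme string directly against the phone symmap. A minimal sketch of the new check, reusing names from the diff below; `get_phone_symmap` is assumed to return the phoneme-to-id mapping, and `utterances` is a hypothetical (path, phonemes) iterable standing in for the dataset walk:

	symmap = get_phone_symmap()  # assumed: maps phoneme string -> token id
	missing = []

	for path, phonemes in utterances:  # hypothetical iterable, see above
		for i, phone in enumerate( phonemes ):
			if phone in symmap or phone in missing:
				continue  # known phoneme, or already reported
			_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
			missing.append( phone )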
@@ -16,6 +16,7 @@ from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
+from .utils import setup_logging
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -594,7 +595,7 @@ class Dataset(_Dataset):
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
 		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 		self.sampler_order = cfg.dataset.sample_order
-		self.sampler_shuffle = cfg.dataset.sample_shuffle
+		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
@@ -878,21 +879,21 @@ class Dataset(_Dataset):
 		return path, text, resps
 
 	def get_similar_utterance(self, spkr_name, reference, offset=0 ):
-		# lots of boilerplate checks
-		metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
-		if not metadata_path.exists():
-			return None
-		metadata = json_read( metadata_path )
+		metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} )
 		if reference not in metadata:
 			return None
 
 		reference_metadata = metadata[reference]
 
 		if "similar" not in reference_metadata:
 			return None
 
 		if len(reference_metadata["similar"]) >= offset:
-			offset = -1
+			offset = 0
 
 		metadata_keys = list(metadata.keys())
-		index = reference_metadata["similar"][offset]
-		name = metadata_keys[index]
+		name = metadata_keys[reference_metadata["similar"][offset]]
 		return name
 
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
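Here `similar` holds indices into the speaker's metadata key order, and the function resolves an offset into that list back to an utterance name. A minimal sketch of the same lookup with an explicit bounds clamp, illustrative only and not the committed logic, with `metadata` and `reference_metadata` as in the function above:

	similar = reference_metadata["similar"]
	if not ( 0 <= offset < len(similar) ):
		offset = 0  # clamp out-of-range offsets back to the closest match
	metadata_keys = list(metadata.keys())
	name = metadata_keys[ similar[offset] ]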
@@ -1259,7 +1260,7 @@ def _seed_worker(worker_id):
 
 def _create_dataloader(dataset, training):
 	kwargs = dict(
-		shuffle=False,
+		shuffle=not training,
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
 		drop_last=training,
 		sampler=dataset.sampler if training else None,
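`shuffle` and `sampler` are mutually exclusive in `torch.utils.data.DataLoader`; passing both raises a ValueError. `shuffle=not training` therefore only ever enables shuffling on the evaluation path, where `sampler` is None, while the training path keeps its custom sampler. A runnable toy illustration of the two paths:

	from torch.utils.data import DataLoader, RandomSampler

	data = list(range(8))  # stands in for the Dataset

	# training path: the custom sampler drives ordering, so shuffle stays False
	train_loader = DataLoader( data, batch_size=2, shuffle=False, drop_last=True, sampler=RandomSampler(data) )

	# evaluation path: no sampler, so the loader's own shuffling is safe to enable
	eval_loader = DataLoader( data, batch_size=2, shuffle=True, drop_last=False, sampler=None )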
@@ -1470,7 +1471,7 @@ def create_dataset_hdf5( skip_existing=True ):
 	metadata_root = str(cfg.metadata_dir)
 
 
-	def add( dir, type="training", audios=True, texts=True ):
+	def add( dir, type="training", audios=True, texts=True, verbose=False ):
 		name = str(dir)
 		name = name.replace(root, "")
 		speaker_name = remap_speaker_name( name )
@@ -1512,7 +1513,7 @@ def create_dataset_hdf5( skip_existing=True ):
 			group.create_dataset('text', data=phn, compression='lzf')
 		"""
 
-		for id in tqdm(ids, desc=f"Processing {name}"):
+		for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
 			try:
 				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
 				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
@@ -1594,6 +1595,7 @@ if __name__ == "__main__":
 
 	task = args.action
 
+	setup_logging()
 	cfg.dataset.workers = 1
 
 	if args.action == "hdf5":
@@ -1641,29 +1643,12 @@ if __name__ == "__main__":
 				_logger.info(f'{k}[{i}]: {v[i]}')
 	elif args.action == "validate":
 		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 
-		missing = set()
-
 		dataset = train_dl.dataset
 
+		missing = []
+		symmap = get_phone_symmap()
 		for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
-			"""
-			batch = train_dl.dataset[i]
-			text = batch['text']
-			phonemes = batch['metadata']['phonemes']
-
-			decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
-			for i, token in enumerate(decoded):
-				if token != "<unk>":
-					continue
-
-				phone = phonemes[i]
-
-				_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
-
-				missing |= set([phone])
-			"""
-
 			if dataset.sampler_type == "group":
 				spkr_group = dataset.spkr_groups[index]
 				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
@@ -1688,21 +1673,36 @@ if __name__ == "__main__":
 				if key not in cfg.hdf5:
 					continue
 				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
-				text = cfg.hdf5[key]["text"][:]
 			else:
 				_, metadata = _load_quants(path, return_metadata=True)
-				text = tokenize( phonemes )
 			phonemes = metadata["phonemes"]
 
+			for i, phone in enumerate( phonemes ):
+				if phone in symmap:
+					continue
+				if phone in missing:
+					continue
+
+				_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
+				missing.append( phone )
+
+			"""
+			text = tokenize( phonemes )[1:-1]
+			unk_token = tokenize("<unk>")[1]
+
+			if unk_token in text:
+				print( unk_token, text, phonemes )
+
 			for i, token in enumerate(text):
-				if token != "<unk>":
+				if token != unk_token:
 					continue
 
 				phone = phonemes[i]
-				_logger.info( f"{path}: {phonemes}: {phone}" )
+				if phone not in missing:
+					_logger.info( f"{path} | {phonemes}[{i}] | {phone}" )
 
 				missing |= set([phone])
+			"""
 
 		_logger.info( f"Missing tokens: {missing}" )
 
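The loop retired into the docstring above also records the original failure mode: `text` holds integer token ids, so `token != "<unk>"` compared an int to a string, was always true, and nothing was ever flagged. The commented replacement compares ids to ids via `tokenize("<unk>")[1]`, presumably the middle element of a BOS/id/EOS triple given the `[1:-1]` slice beside it, but the live path above sidesteps the tokenizer entirely. A toy illustration of the type mismatch:

	text = [1, 7, 3, 2]          # token ids, as tokenize() returns
	print( text[1] != "<unk>" )  # True for every id: an int never equals a str
	unk_token = 3                # e.g. tokenize("<unk>")[1], assuming id 3
	print( unk_token in text )   # True: id-to-id comparison works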
@@ -80,6 +80,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
 		# not a dict of tensors, put it as metadata
 		try:
 			metadata[k] = json.dumps(v)
+			if isinstance( metadata[k], bytes ):
+				metadata[k] = metadata[k].decode('utf-8')
 		except Exception as e:
 			pass
 
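The hunk above is from a second file, against the serialization helper `state_dict_to_tensor_metadata`. The guard only matters if `json` is not the stdlib module: stdlib `json.dumps` always returns `str`, while `orjson.dumps` returns `bytes`, and this metadata ultimately has to be a flat str-to-str mapping (safetensors headers require exactly that). A sketch of the normalization, assuming the module may alias `orjson` as `json` when available:

	try:
		import orjson as json  # orjson.dumps returns bytes
	except ImportError:
		import json  # stdlib json.dumps returns str

	value = json.dumps( {"a": 1} )
	if isinstance( value, bytes ):
		value = value.decode('utf-8')
	print( type(value).__name__ )  # 'str' either way after the guard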