diff --git a/vall_e/data.py b/vall_e/data.py index e0d626d..975505c 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge from .emb.g2p import encode as encode_phns from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler from .utils.distributed import global_rank, local_rank, world_size -from .utils.io import torch_save, torch_load, json_read, json_write +from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse from collections import defaultdict from functools import cache, cached_property @@ -473,7 +473,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): if cfg.dataset.use_metadata and metadata_path.exists(): #metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) - metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) + metadata = json_read( metadata_path ) if len(metadata) == 0: return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) @@ -554,7 +554,8 @@ def _get_phones(path): phone_path = _get_phone_path(path) quant_path = _get_quant_path(path) if phone_path.exists(): - metadata = json.loads(open(phone_path, "r", encoding="utf-8").read()) + #metadata = json.loads(open(phone_path, "r", encoding="utf-8").read()) + metadata = json_read(phone_path) elif quant_path.exists(): _, metadata = _load_quants( path, return_metadata=True ) else: @@ -879,7 +880,7 @@ class Dataset(_Dataset): metadata_path = Path(f"{metadata_root}/{speaker_name}.json") if not metadata_path.exists(): return None - metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) + metadata = json_read( metadata_path ) if reference not in metadata: return None reference_metadata = metadata[reference] @@ -1335,18 +1336,28 @@ def create_train_val_dataloader(): return train_dl, subtrain_dl, val_dl +# parse metadata from an numpy file (.enc/.dac) and validate it def process_artifact_metadata( artifact ): metadata = {} + # text transcription (just in case) if "text" in artifact["metadata"]: metadata["text"] = artifact["metadata"]["text"] + # phonemization of text transcription (just in case) if "phonemes" in artifact["metadata"]: metadata["phonemes"] = artifact["metadata"]["phonemes"] + # language for sampling / input creation if "language" in artifact["metadata"]: metadata["language"] = artifact["metadata"]["language"] - if "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]: + # top-k similar utterances for this utternace + if "similar" in artifact["metadata"]: + metadata["similar"] = artifact["metadata"]["similar"] + # duration for use of culling / sorting dataset + if "duration" in artifact["metadata"]: + metadata["duration"] = duration + # derive duration from sample count / sample rate + elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]: metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] - # rephonemize if required if "phonemes" not in metadata and "text" in metadata: metadata["phonemes"] = encode_phns( metadata["text"], language=metadata["language"] if "language" in metadata["language"] else "en" ) @@ -1361,6 +1372,16 @@ def process_artifact_metadata( artifact ): return metadata +# yucky, but I would like to have the LibriTTS-R utterances remapped to their LibriSpeech counterpart +# to-do: allow this to be adjusted without having to regenerate metadata / HDF5 by remapping name during dataloader creation +def remap_speaker_name( name ): + # commented out because I don't want the LibriSpeech portion of the dataset to get added + """ + if "LibriTTS-R" in speaker_name: + name = name.replace("LibriTTS-R", "LibriVox") + """ + return name + # parse dataset into better to sample metadata def create_dataset_metadata( skip_existing=True ): symmap = get_phone_symmap() @@ -1373,25 +1394,16 @@ def create_dataset_metadata( skip_existing=True ): def add( dir, type="training", audios=True, texts=True ): name = str(dir) name = name.replace(root, "") - - speaker_name = name - """ - if "LibriTTS-R" in speaker_name: - speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") - """ + speaker_name = remap_speaker_name( name ) metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path.parents[0].mkdir(parents=True, exist_ok=True) - try: - metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) - except Exception as e: - metadata = {} + metadata = json_read( metadata_path, default={} ) if not os.path.isdir(f'{root}/{name}/'): return - # tqdm.write(f'{root}/{name}') files = os.listdir(f'{root}/{name}/') # grab IDs for every file @@ -1430,8 +1442,7 @@ def create_dataset_metadata( skip_existing=True ): tqdm.write(f'Error while processing {id}: {e}') if wrote: - with open(str(metadata_path), "w", encoding="utf-8") as f: - f.write( json.dumps( metadata ) ) + json_write( metadata, metadata_path ) # training for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): @@ -1460,16 +1471,12 @@ def create_dataset_hdf5( skip_existing=True ): def add( dir, type="training", audios=True, texts=True ): name = str(dir) name = name.replace(root, "") - - # yucky - speaker_name = name - if "LibriTTS-R" in speaker_name: - speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") + speaker_name = remap_speaker_name( name ) metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path.parents[0].mkdir(parents=True, exist_ok=True) - metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) + metadata = json_read(metadata_path, default={}) if not os.path.isdir(f'{root}/{name}/'): return @@ -1534,9 +1541,11 @@ def create_dataset_hdf5( skip_existing=True ): group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf') # text + # this is a relic from when I did have the quantized audio and phoneme transcription separate + # to-do: ensure I can remove this block if texts: if not utterance_metadata and text_exists: - utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) + utterance_metadata = json_read(f'{root}/{name}/{id}{_get_phone_extension()}') phn = "".join(utterance_metadata["phonemes"]) phn = cfg.tokenizer.encode(phn) @@ -1552,8 +1561,7 @@ def create_dataset_hdf5( skip_existing=True ): except Exception as e: tqdm.write(f'Error while processing {id}: {e}') - with open(str(metadata_path), "w", encoding="utf-8") as f: - f.write( json.dumps( metadata ) ) + json_write( metadata, metadata_path ) # training for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"): @@ -1571,7 +1579,7 @@ def create_dataset_hdf5( skip_existing=True ): if "symmap" in hf: del hf['symmap'] - hf.create_dataset('symmap', data=json.dumps(symmap)) + hf.create_dataset('symmap', data=json_stringify(symmap)) hf.close() if __name__ == "__main__": @@ -1596,7 +1604,7 @@ if __name__ == "__main__": continue dataset.append(f'{group}/{name}') - _logger.info(json.dumps(dataset)) + _logger.info(json_stringify(dataset)) elif args.action == "metadata": create_dataset_metadata() elif args.action == "sample": diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 5f68877..4453f48 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -54,6 +54,7 @@ def process( verbose=False, metadata_path=None, top_k=8, + metadata_keys=[], trim_duration=0, min_duration=0, @@ -73,13 +74,16 @@ def process( if tts is None: tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype ) - features = {} + features = { key: None for key in metadata_keys } mfcc = None simplified_metadata = True # aims to slim down the raw data in the JSON to store slop = True # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier + if not speaker_path.exists(): + return + # compute features (embeddings if quantized already, MFCC features if raw audio) for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose): extension = filename.split(".")[-1] @@ -92,11 +96,13 @@ def process( artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] + """ if 0 < min_duration and duration < min_duration: continue if 0 < max_duration and max_duration < duration: continue + """ lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en" if "phonemes" in artifact["metadata"]: @@ -178,22 +184,35 @@ def process( # do batch cosine similarity processing keys = list(features.keys()) - embeddings = torch.stack( list( features.values() ) ) + top_k = min( top_k, len(keys) ) + + if top_k == 0: + return + + null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype ) + embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] ) sorted_similarities = {} + for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"): + if features[filename] is None: + continue + embedding = features[filename].unsqueeze(0) similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1) + + # sorting is slow, don't bother + #sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) + # set current index to -inf similarities[index] = float("-inf") - similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist() - # similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() + + topk = torch.topk(similarities, k=top_k, largest=True, sorted=True) + similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) ] sorted_similarities[filename] = similarities - # sorting is slow, don't bother - #sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) return sorted_similarities @@ -221,6 +240,8 @@ def main(): args = parser.parse_args() + args.skip_existing = False # + if args.use_dataset: cfg.metadata_dir.mkdir(parents=True, exist_ok=True) @@ -228,10 +249,17 @@ def main(): name = str(dir) name = name.replace(str(cfg.data_dir), "") speaker_name = name + """ if "LibriTTS-R" in speaker_name: speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") + """ metadata_path = cfg.metadata_dir / f'{speaker_name}.json' + metadata = json_read( metadata_path, default={} ) + metadata_keys = list(metadata.keys()) if metadata else [] + + if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]: + return similarities = process( speaker_path=cfg.data_dir / speaker_name, @@ -242,6 +270,7 @@ def main(): #min_duration=args.min_duration, #max_duration=args.max_duration, storage_backend=args.storage_backend, + metadata_keys=metadata_keys, audio_backend=args.audio_backend, device=args.device, @@ -250,29 +279,23 @@ def main(): verbose=True, ) + + if not similarities: + return if args.storage_backend == "faiss": faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss"))) return - - #metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {} - metadata = json_read( metadata_path, default={} ) - metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys()) - - for filename, sim in similarities.items(): + + for filename, similar in similarities.items(): if filename not in metadata: metadata[filename] = {} - metadata[filename]["similar"] = sim + # overkill but i'm very paranoid about mismatching indices + metadata[filename]["similar"] = [ metadata_keys.index(s[1]) for s in similar ] json_write( metadata, metadata_path ) - """ - with open(str(metadata_path), "wb") as f: - f.write( json.dumps( metadata ) ) - #f.write( truncate_json( json.dumps( metadata ) ) ) - """ - # training for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): add( data_dir, type="training" ) @@ -286,7 +309,7 @@ def main(): add( data_dir, type="noise", texts=False ) elif args.input_speaker: - process( + similarities = process( speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, @@ -304,6 +327,10 @@ def main(): storage_backend=args.storage_backend, verbose=True, ) + + # and print + for filename, sim in similarities.items(): + print(f'{filename}: {sim}') else: raise Exception("!") diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index cba071f..8030ff5 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -11,7 +11,11 @@ try: except: import json -def json_stringify( data ): +from .utils import truncate_json + +def json_stringify( data, truncate=False ): + if truncate: + return truncate_json( json.dumps( data ) ) return json.dumps( data ) def json_parse( string ): @@ -26,11 +30,11 @@ def json_read( path, default=None ): with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f: return json_parse( f.read() ) -def json_write( data, path ): +def json_write( data, path, truncate=False ): path = coerce_path( path ) with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f: - f.write( json_stringify( data ) ) + f.write( json_stringify( data, truncate=truncate ) ) def coerce_path( path ): return path if isinstance( path, Path ) else Path(path)