From f00283440cb0411b200dafa80ec5f638f258c624 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 17 Sep 2024 22:26:31 -0500 Subject: [PATCH] faster --- vall_e/emb/similar.py | 176 +++++++++++------------------------------- 1 file changed, 47 insertions(+), 129 deletions(-) diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index a8617a9..2b44156 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -58,7 +58,7 @@ def process( min_duration=0, max_duration=0, - storage_backend="local" + storage_backend="slop" ): global tts @@ -72,23 +72,15 @@ def process( if tts is None: tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype ) - queue = [] features = {} - similarities = {} - sorted_similarities = {} mfcc = None - slop = False # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier - if storage_backend == "faiss": - slop = True - elif storage_backend == "chunkdot": - slop = True - elif storage_backend == "slop": - slop = True + simplified_metadata = True # aims to slim down the raw data in the JSON to store + slop = True # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier # compute features (embeddings if quantized already, MFCC features if raw audio) - for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose): + for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose): extension = filename.split(".")[-1] filename = filename.replace(f".{extension}", "") @@ -123,11 +115,13 @@ def process( artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] + """ if 0 < min_duration and duration < min_duration: continue if 0 < max_duration and max_duration < duration: continue + """ qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device) @@ -142,11 +136,13 @@ def process( duration = wav.shape[-1] / sr + """ if 0 < min_duration and duration < min_duration: continue if 0 < max_duration and max_duration < duration: continue + """ if mfcc is None: mfcc = T.MFCC(sample_rate=cfg.sample_rate) @@ -175,119 +171,20 @@ def process( sim = list(I[0][1:]) print(f'{filename}: {sim}') """ - - if metadata_path is not None: - faiss.write_index(index, str(metadata_path.with_suffix(".faiss"))) - return - - """ - # to-do: actually refine this, maybe - # desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............ - if storage_backend == "chunkdot": - from chunkdot import cosine_similarity_top_k - - embeddings = torch.stack( list( features.values() ) ).cpu().numpy() - similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose) - - print(similarities) - return - """ - - metadata = None - if metadata_path is not None: - metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else None + return index + + # do batch cosine similarity processing keys = list(features.keys()) + embeddings = torch.stack( list( features.values() ) ) + sorted_similarities = {} - # do batch cosine similarity processing - if slop: - embeddings = torch.stack( list( features.values() ) ) - sorted_similarities = {} - - for index, filename in enumerate(keys): - embedding = features[filename].unsqueeze(0) - similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() - similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) - - sorted_similarities[filename] = similarities - - most_filename, most_score = similarities[0] - least_filename, least_score = similarities[-1] - - if metadata is not None: - if filename not in metadata: - metadata[filename] = {} - metadata[filename]["similar"] = similarities - - if verbose: - print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' ) - - if metadata is not None: - with open(str(metadata_path), "w", encoding="utf-8") as f: - f.write( truncate_json( json.dumps( metadata ) ) ) - - return sorted_similarities - - # an EXTREMELY naive implementation, fucking disgusting - queue = list(combinations(range(len(keys)), 2)) - for key in tqdm(queue, desc="Computing similarities", disable=not verbose): - index_a, index_b = key - filename_a, filename_b = keys[index_a], keys[index_b] - swapped_key = (index_b, index_a) - - if swapped_key in similarities: - similarities[key] = similarities[swapped_key] - continue - - if slop: - embedding_a = features[filename_a] - embedding_b = features[filename_b] - - similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item() - else: - shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] ) - embedding_a = features[filename_a][:shortest, :] - embedding_b = features[filename_b][:shortest, :] - - similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item() - - similarities[key] = similarity - - # combinations() doesn't have swapped keys - if swapped_key not in similarities: - similarities[swapped_key] = similarity - - if index_a not in sorted_similarities: - sorted_similarities[index_a] = {} - if index_b not in sorted_similarities[index_a]: - sorted_similarities[index_a][index_b] = similarity - - if index_b not in sorted_similarities: - sorted_similarities[index_b] = {} - if index_a not in sorted_similarities[index_b]: - sorted_similarities[index_b][index_a] = similarity - - # sort similarities scores - for key, sorted_similarity in sorted_similarities.items(): - sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True) - - most_filename, most_score = sorted_similarities[key][0] - least_filename, least_score = sorted_similarities[key][-1] - - filename = keys[key] - - if metadata is not None: - if filename not in metadata: - metadata[filename] = {} - metadata[filename]["similar"] = sorted_similarities[key] - - #if verbose: - # print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' ) - - if metadata is not None: - with open(str(metadata_path), "w", encoding="utf-8") as f: - f.write( truncate_json( json.dumps( metadata ) ) ) + for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"): + embedding = features[filename].unsqueeze(0) + similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() + sorted_similarities[filename] = similarities + #sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) return sorted_similarities @@ -299,9 +196,11 @@ def main(): parser.add_argument("--yaml", type=Path) parser.add_argument("--text", action="store_true") + """ parser.add_argument("--trim-duration", type=float, default=3.0) parser.add_argument("--min-duration", type=float, default=0) parser.add_argument("--max-duration", type=float, default=0) + """ parser.add_argument("--storage-backend", type=str, default="slop") parser.add_argument("--audio-backend", type=str, default="encodec") @@ -320,15 +219,16 @@ def main(): speaker_name = name if "LibriTTS-R" in speaker_name: speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") + + metadata_path = cfg.metadata_dir / f'{speaker_name}.json' - process( + similarities = process( speaker_path=cfg.data_dir / speaker_name, - metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss', yaml=args.yaml, text=args.text, - trim_duration=args.trim_duration, - min_duration=args.min_duration, - max_duration=args.max_duration, + #trim_duration=args.trim_duration, + #min_duration=args.min_duration, + #max_duration=args.max_duration, storage_backend=args.storage_backend, audio_backend=args.audio_backend, @@ -339,6 +239,23 @@ def main(): verbose=True, ) + if args.storage_backend == "faiss": + faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss"))) + return + + metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {} + metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys()) + + for filename, sim in similarities.items(): + if filename not in metadata: + metadata[filename] = {} + + metadata[filename]["similar"] = sim + + with open(str(metadata_path), "w", encoding="utf-8") as f: + f.write( json.dumps( metadata ) ) + #f.write( truncate_json( json.dumps( metadata ) ) ) + # training for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): add( data_dir, type="training" ) @@ -356,9 +273,10 @@ def main(): speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, - trim_duration=args.trim_duration, - min_duration=args.min_duration, - max_duration=args.max_duration, + + #trim_duration=args.trim_duration, + #min_duration=args.min_duration, + #max_duration=args.max_duration, audio_backend=args.audio_backend, device=args.device,