diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index d3afea0..a8617a9 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -24,7 +24,6 @@ import torchaudio.transforms as T from ..config import cfg from ..utils import truncate_json -# need to validate if this is safe to import before modifying the config from .g2p import encode as phonemize from .qnt import encode as quantize, trim, convert_audio @@ -55,8 +54,11 @@ def process( verbose=False, metadata_path=None, - maximum_duration=0, - #use_faiss=True, + trim_duration=0, + min_duration=0, + max_duration=0, + + storage_backend="local" ): global tts @@ -75,14 +77,15 @@ def process( similarities = {} sorted_similarities = {} - mfcc = T.MFCC(sample_rate=cfg.sample_rate) + mfcc = None - """ - # too slow - if use_faiss: - import faiss - index = None - """ + slop = False # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier + if storage_backend == "faiss": + slop = True + elif storage_backend == "chunkdot": + slop = True + elif storage_backend == "slop": + slop = True # compute features (embeddings if quantized already, MFCC features if raw audio) for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose): @@ -94,6 +97,13 @@ def process( raise Exception("!") artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] + duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] + + if 0 < min_duration and duration < min_duration: + continue + + if 0 < max_duration and max_duration < duration: + continue lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en" if "phonemes" in artifact["metadata"]: @@ -111,42 +121,116 @@ def process( # treat embeddings as features, if provided quantized audio if extension in artifact_extension: artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] + duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] + + if 0 < min_duration and duration < min_duration: + continue + + if 0 < max_duration and max_duration < duration: + continue + qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device) - if maximum_duration > 0: - qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) ) + + if trim_duration > 0: + qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) ) embedding = tts.audio_embedding( qnt ) # try and extract features from the raw audio itself else: # qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device) wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' ) + + duration = wav.shape[-1] / sr + + if 0 < min_duration and duration < min_duration: + continue + + if 0 < max_duration and max_duration < duration: + continue + + if mfcc is None: + mfcc = T.MFCC(sample_rate=cfg.sample_rate) + embedding = mfcc(wav.to(device))[0].t() + if slop: + embedding = embedding.sum(dim=0) + features[filename] = embedding + + # rely on FAISS to handle storing embeddings and handling queries + # will probably explode in size fast........... + if storage_backend == "faiss": + import faiss + index = faiss.IndexFlatL2( embeddings.shape[-1] ) + embeddings = torch.stack( list( features.values() ) ).cpu() + index.add( embeddings ) + """ - if use_faiss: - if index is None: - shape = embedding.shape - index = faiss.IndexFlatL2(shape[1]) - - index.add(embedding.cpu()) - + # to-do: support just querying for list of similar to cram into JSON metadata if verbose: for filename, embedding in features.items(): - D, I = index.search(embedding.cpu(), k=3) - # print(f'{filename}: {I[1:]}') - - if metadata_path is not None: - index.save(metadata_path) + D, I = index.search(embedding.unsqueeze(0).cpu(), k=2) + sim = list(I[0][1:]) + print(f'{filename}: {sim}') """ - keys = list(features.keys()) - key_range = range(len(keys)) - # queue = [ (index_a, index_b) for index_b in key_range for index_a in key_range if index_a != index_b ] - queue = list(combinations(key_range, 2)) + if metadata_path is not None: + faiss.write_index(index, str(metadata_path.with_suffix(".faiss"))) + + return - # compute similarities for every utternace + """ + # to-do: actually refine this, maybe + # desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............ + if storage_backend == "chunkdot": + from chunkdot import cosine_similarity_top_k + + embeddings = torch.stack( list( features.values() ) ).cpu().numpy() + similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose) + + print(similarities) + return + """ + + metadata = None + if metadata_path is not None: + metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else None + + keys = list(features.keys()) + + # do batch cosine similarity processing + if slop: + embeddings = torch.stack( list( features.values() ) ) + sorted_similarities = {} + + for index, filename in enumerate(keys): + embedding = features[filename].unsqueeze(0) + similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() + similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) + + sorted_similarities[filename] = similarities + + most_filename, most_score = similarities[0] + least_filename, least_score = similarities[-1] + + if metadata is not None: + if filename not in metadata: + metadata[filename] = {} + metadata[filename]["similar"] = similarities + + if verbose: + print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' ) + + if metadata is not None: + with open(str(metadata_path), "w", encoding="utf-8") as f: + f.write( truncate_json( json.dumps( metadata ) ) ) + + return sorted_similarities + + # an EXTREMELY naive implementation, fucking disgusting + queue = list(combinations(range(len(keys)), 2)) for key in tqdm(queue, desc="Computing similarities", disable=not verbose): index_a, index_b = key filename_a, filename_b = keys[index_a], keys[index_b] @@ -156,8 +240,17 @@ def process( similarities[key] = similarities[swapped_key] continue - shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] ) - similarity = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item() + if slop: + embedding_a = features[filename_a] + embedding_b = features[filename_b] + + similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item() + else: + shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] ) + embedding_a = features[filename_a][:shortest, :] + embedding_b = features[filename_b][:shortest, :] + + similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item() similarities[key] = similarity @@ -175,13 +268,6 @@ def process( if index_a not in sorted_similarities[index_b]: sorted_similarities[index_b][index_a] = similarity - metadata = None - if metadata_path is not None: - if metadata_path.exists(): - metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) - else: - metadata = {} - # sort similarities scores for key, sorted_similarity in sorted_similarities.items(): sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True) @@ -201,9 +287,7 @@ def process( if metadata is not None: with open(str(metadata_path), "w", encoding="utf-8") as f: - serialized = json.dumps( metadata ) - serialized = truncate_json( serialized ) - f.write( serialized ) + f.write( truncate_json( json.dumps( metadata ) ) ) return sorted_similarities @@ -215,10 +299,13 @@ def main(): parser.add_argument("--yaml", type=Path) parser.add_argument("--text", action="store_true") - parser.add_argument("--maximum-duration", type=float, default=3.0) + parser.add_argument("--trim-duration", type=float, default=3.0) + parser.add_argument("--min-duration", type=float, default=0) + parser.add_argument("--max-duration", type=float, default=0) + parser.add_argument("--storage-backend", type=str, default="slop") parser.add_argument("--audio-backend", type=str, default="encodec") - parser.add_argument("--dtype", type=str, default="float16") + parser.add_argument("--dtype", type=str, default="float32") parser.add_argument("--amp", action="store_true") parser.add_argument("--device", type=str, default="cpu") # unironically faster @@ -236,10 +323,13 @@ def main(): process( speaker_path=cfg.data_dir / speaker_name, - metadata_path=cfg.metadata_dir / f'{speaker_name}.json', + metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss', yaml=args.yaml, text=args.text, - maximum_duration=args.maximum_duration, + trim_duration=args.trim_duration, + min_duration=args.min_duration, + max_duration=args.max_duration, + storage_backend=args.storage_backend, audio_backend=args.audio_backend, device=args.device, @@ -260,18 +350,22 @@ def main(): # noise for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'): add( data_dir, type="noise", texts=False ) + elif args.input_speaker: process( speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, - maximum_duration=args.maximum_duration, + trim_duration=args.trim_duration, + min_duration=args.min_duration, + max_duration=args.max_duration, audio_backend=args.audio_backend, device=args.device, dtype=args.dtype, amp=args.amp, + storage_backend=args.storage_backend, verbose=True, ) else: