optimizations (6 hours to do cosine similarities on a speaker set of just 17k utterances................)

mrq 2024-09-17 15:51:45 -05:00
parent a9fbe81f98
commit 804ddb5182

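The change below speeds up the pairwise similarity pass: the old code enumerated every ordered pair of utterances under string keys and de-duplicated them with `key in queue`, a linear scan over an ever-growing list, while the new code enumerates each unordered pair exactly once with `itertools.combinations` over integer indices and fills both directions of the symmetric score in one pass. A minimal sketch of the idea (a toy `features` dict stands in for the script's real MFCC/embedding features; the names here are illustrative, not the module's API):

from itertools import combinations

import torch

# toy stand-in for the per-utterance feature dict that process() builds
features = { f"utt_{i}": torch.randn(50 + i, 20) for i in range(4) }

keys = list(features.keys())
similarities = {}

# each unordered pair is visited once: n*(n-1)/2 pairs instead of n*n string-keyed ones
for index_a, index_b in combinations(range(len(keys)), 2):
	a, b = features[keys[index_a]], features[keys[index_b]]
	shortest = min(a.shape[0], b.shape[0])
	score = torch.nn.functional.cosine_similarity(a[:shortest], b[:shortest], dim=1).mean().item()
	# cosine similarity is symmetric, so both directions come from one computation
	similarities[(index_a, index_b)] = score
	similarities[(index_b, index_a)] = score

print(len(similarities))  # 12 entries from only 6 cosine computations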

@@ -11,6 +11,8 @@ import torchaudio
 import numpy as np
 import logging
+from itertools import combinations
+
 _logger = logging.getLogger(__name__)

 from tqdm.auto import tqdm
@@ -68,7 +70,7 @@ def process(
 	mfcc = T.MFCC(sample_rate=cfg.sample_rate)

 	# compute features (embeddings if quantized already, MFCC features if raw audio)
-	for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc="Encoding...", disable=not verbose):
+	for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
 		extension = filename.split(".")[-1]
 		filename = filename.replace(f".{extension}", "")
@@ -104,43 +106,35 @@ def process(
 			wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
 			features[filename] = mfcc(wav.to(device))[0].t()

 	# calculate pairs, flattened because it makes tqdm nicer
-	for filename_a, embedding_a in features.items():
-		for filename_b, embedding_b in features.items():
-			if filename_a == filename_b:
-				continue
-
-			key = f'{filename_a}:{filename_b}'
-
-			if key in queue:
-				continue
-
-			queue.append(key)
+	keys = list(features.keys())
+	key_range = range(len(keys))
+	# queue = [ (index_a, index_b) for index_b in key_range for index_a in key_range if index_a != index_b ]
+	queue = list(combinations(key_range, 2))

 	# compute similarities for every utternace
 	for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
-		filename_a, filename_b = key.split(":")
-		swapped_key = f'{filename_b}:{filename_a}'
+		index_a, index_b = key
+		filename_a, filename_b = keys[index_a], keys[index_b]
+		swapped_key = (index_b, index_a)

 		if swapped_key in similarities:
 			similarities[key] = similarities[swapped_key]
 			continue

 		shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
-		similarities[key] = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
-
-	# ???
-	for key, similarity in similarities.items():
-		filename_a, filename_b = key.split(":")
-
-		if filename_a not in sorted_similarities:
-			sorted_similarities[filename_a] = {}
-		if filename_b not in sorted_similarities[filename_a]:
-			sorted_similarities[filename_a][filename_b] = similarity
-
-		if filename_b not in sorted_similarities:
-			sorted_similarities[filename_b] = {}
-		if filename_a not in sorted_similarities[filename_b]:
-			sorted_similarities[filename_b][filename_a] = similarity
+		similarity = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
+		similarities[key] = similarity
+
+		if index_a not in sorted_similarities:
+			sorted_similarities[index_a] = {}
+		if index_b not in sorted_similarities[index_a]:
+			sorted_similarities[index_a][index_b] = similarity
+
+		if index_b not in sorted_similarities:
+			sorted_similarities[index_b] = {}
+		if index_a not in sorted_similarities[index_b]:
+			sorted_similarities[index_b][index_a] = similarity

 	metadata = None
 	if metadata_path is not None:
@@ -150,19 +144,21 @@ def process(
 		metadata = {}

 	# sort similarities scores
-	for filename, sorted_similarity in sorted_similarities.items():
-		sorted_similarities[filename] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
+	for key, sorted_similarity in sorted_similarities.items():
+		sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)

-		most_filename, most_score = sorted_similarities[filename][0]
-		least_filename, least_score = sorted_similarities[filename][-1]
+		most_filename, most_score = sorted_similarities[key][0]
+		least_filename, least_score = sorted_similarities[key][-1]
+		filename = keys[key]

 		if metadata is not None:
 			if filename not in metadata:
 				metadata[filename] = {}
-			metadata[filename]["similar"] = sorted_similarities[filename]
+			metadata[filename]["similar"] = sorted_similarities[key]

-		if verbose:
-			print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
+		#if verbose:
+		#	print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )

 	if metadata is not None:
 		with open(str(metadata_path), "w", encoding="utf-8") as f:
@@ -209,7 +205,7 @@ def main():
 		dtype=args.dtype,
 		amp=args.amp,
-		verbose=False,
+		verbose=True,
 	)

 	# training
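For reference, the ranking step at the end of `process()` works the same way on the index-keyed map; a rough sketch continuing from the toy example above (the real script additionally writes the ranking into a metadata file):

# group scores per utterance index, then rank most similar first
sorted_similarities = {}
for (index_a, index_b), score in similarities.items():
	sorted_similarities.setdefault(index_a, {})[index_b] = score

for key, per_key in sorted_similarities.items():
	ranked = sorted(per_key.items(), key=lambda x: x[1], reverse=True)
	most_index, most_score = ranked[0]
	least_index, least_score = ranked[-1]
	# indices are only mapped back to filenames when reporting
	print(f"{keys[key]}: most={keys[most_index]} ({most_score:.3f}), least={keys[least_index]} ({least_score:.3f})")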