optimizations (6 hours to do cosine similarities on a speaker set of just 17k utterances................)

This commit is contained in:
mrq 2024-09-17 15:51:45 -05:00
parent a9fbe81f98
commit 804ddb5182

View File

@ -11,6 +11,8 @@ import torchaudio
import numpy as np import numpy as np
import logging import logging
from itertools import combinations
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -68,7 +70,7 @@ def process(
mfcc = T.MFCC(sample_rate=cfg.sample_rate) mfcc = T.MFCC(sample_rate=cfg.sample_rate)
# compute features (embeddings if quantized already, MFCC features if raw audio) # compute features (embeddings if quantized already, MFCC features if raw audio)
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc="Encoding...", disable=not verbose): for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
extension = filename.split(".")[-1] extension = filename.split(".")[-1]
filename = filename.replace(f".{extension}", "") filename = filename.replace(f".{extension}", "")
@ -104,43 +106,35 @@ def process(
wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' ) wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
features[filename] = mfcc(wav.to(device))[0].t() features[filename] = mfcc(wav.to(device))[0].t()
# calculate pairs, flattened because it makes tqdm nicer keys = list(features.keys())
for filename_a, embedding_a in features.items(): key_range = range(len(keys))
for filename_b, embedding_b in features.items(): # queue = [ (index_a, index_b) for index_b in key_range for index_a in key_range if index_a != index_b ]
if filename_a == filename_b: queue = list(combinations(key_range, 2))
continue
key = f'{filename_a}:{filename_b}'
if key in queue:
continue
queue.append(key)
# compute similarities for every utternace # compute similarities for every utternace
for key in tqdm(queue, desc="Computing similarities", disable=not verbose): for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
filename_a, filename_b = key.split(":") index_a, index_b = key
swapped_key = f'{filename_b}:{filename_a}' filename_a, filename_b = keys[index_a], keys[index_b]
swapped_key = (index_b, index_a)
if swapped_key in similarities: if swapped_key in similarities:
similarities[key] = similarities[swapped_key] similarities[key] = similarities[swapped_key]
continue continue
shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] ) shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
similarities[key] = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item() similarity = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
# ??? similarities[key] = similarity
for key, similarity in similarities.items():
filename_a, filename_b = key.split(":")
if filename_a not in sorted_similarities: if index_a not in sorted_similarities:
sorted_similarities[filename_a] = {} sorted_similarities[index_a] = {}
if filename_b not in sorted_similarities[filename_a]: if index_b not in sorted_similarities[index_a]:
sorted_similarities[filename_a][filename_b] = similarity sorted_similarities[index_a][index_b] = similarity
if filename_b not in sorted_similarities: if index_b not in sorted_similarities:
sorted_similarities[filename_b] = {} sorted_similarities[index_b] = {}
if filename_a not in sorted_similarities[filename_b]: if index_a not in sorted_similarities[index_b]:
sorted_similarities[filename_b][filename_a] = similarity sorted_similarities[index_b][index_a] = similarity
metadata = None metadata = None
if metadata_path is not None: if metadata_path is not None:
@ -150,19 +144,21 @@ def process(
metadata = {} metadata = {}
# sort similarities scores # sort similarities scores
for filename, sorted_similarity in sorted_similarities.items(): for key, sorted_similarity in sorted_similarities.items():
sorted_similarities[filename] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True) sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
most_filename, most_score = sorted_similarities[filename][0] most_filename, most_score = sorted_similarities[key][0]
least_filename, least_score = sorted_similarities[filename][-1] least_filename, least_score = sorted_similarities[key][-1]
filename = keys[key]
if metadata is not None: if metadata is not None:
if filename not in metadata: if filename not in metadata:
metadata[filename] = {} metadata[filename] = {}
metadata[filename]["similar"] = sorted_similarities[filename] metadata[filename]["similar"] = sorted_similarities[key]
if verbose: #if verbose:
print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' ) # print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
if metadata is not None: if metadata is not None:
with open(str(metadata_path), "w", encoding="utf-8") as f: with open(str(metadata_path), "w", encoding="utf-8") as f:
@ -209,7 +205,7 @@ def main():
dtype=args.dtype, dtype=args.dtype,
amp=args.amp, amp=args.amp,
verbose=False, verbose=True,
) )
# training # training