*faster*
This commit is contained in:
parent
f00283440c
commit
6ceed866b5
|
@ -4,7 +4,7 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import orjson as json
|
||||
import argparse
|
||||
import torch
|
||||
import torchaudio
|
||||
|
@ -53,6 +53,7 @@ def process(
|
|||
|
||||
verbose=False,
|
||||
metadata_path=None,
|
||||
top_k=8,
|
||||
|
||||
trim_duration=0,
|
||||
min_duration=0,
|
||||
|
@ -167,7 +168,7 @@ def process(
|
|||
# to-do: support just querying for list of similar to cram into JSON metadata
|
||||
if verbose:
|
||||
for filename, embedding in features.items():
|
||||
D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
|
||||
D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1)
|
||||
sim = list(I[0][1:])
|
||||
print(f'{filename}: {sim}')
|
||||
"""
|
||||
|
@ -180,10 +181,18 @@ def process(
|
|||
embeddings = torch.stack( list( features.values() ) )
|
||||
sorted_similarities = {}
|
||||
|
||||
for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"):
|
||||
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
|
||||
embedding = features[filename].unsqueeze(0)
|
||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
||||
|
||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
|
||||
# set current index to -inf
|
||||
similarities[index] = float("-inf")
|
||||
similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist()
|
||||
# similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
||||
|
||||
sorted_similarities[filename] = similarities
|
||||
|
||||
# sorting is slow, don't bother
|
||||
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
||||
|
||||
return sorted_similarities
|
||||
|
@ -196,17 +205,19 @@ def main():
|
|||
|
||||
parser.add_argument("--yaml", type=Path)
|
||||
parser.add_argument("--text", action="store_true")
|
||||
# dropped, because this might mess with the indices to map to
|
||||
"""
|
||||
parser.add_argument("--trim-duration", type=float, default=3.0)
|
||||
parser.add_argument("--min-duration", type=float, default=0)
|
||||
parser.add_argument("--max-duration", type=float, default=0)
|
||||
"""
|
||||
parser.add_argument("--storage-backend", type=str, default="slop")
|
||||
parser.add_argument("--top-k", type=int, default=8)
|
||||
|
||||
parser.add_argument("--audio-backend", type=str, default="encodec")
|
||||
parser.add_argument("--dtype", type=str, default="float32")
|
||||
parser.add_argument("--dtype", type=str, default="float16")
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--device", type=str, default="cpu") # unironically faster
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -226,6 +237,7 @@ def main():
|
|||
speaker_path=cfg.data_dir / speaker_name,
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
top_k=args.top_k,
|
||||
#trim_duration=args.trim_duration,
|
||||
#min_duration=args.min_duration,
|
||||
#max_duration=args.max_duration,
|
||||
|
@ -252,7 +264,7 @@ def main():
|
|||
|
||||
metadata[filename]["similar"] = sim
|
||||
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
with open(str(metadata_path), "wb") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
#f.write( truncate_json( json.dumps( metadata ) ) )
|
||||
|
||||
|
@ -273,6 +285,7 @@ def main():
|
|||
speaker_path=args.input_speaker,
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
top_k=args.top_k,
|
||||
|
||||
#trim_duration=args.trim_duration,
|
||||
#min_duration=args.min_duration,
|
||||
|
|
Loading…
Reference in New Issue
Block a user