*faster*
This commit is contained in:
parent
f00283440c
commit
6ceed866b5
|
@ -4,7 +4,7 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import orjson as json
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
@ -53,6 +53,7 @@ def process(
|
||||||
|
|
||||||
verbose=False,
|
verbose=False,
|
||||||
metadata_path=None,
|
metadata_path=None,
|
||||||
|
top_k=8,
|
||||||
|
|
||||||
trim_duration=0,
|
trim_duration=0,
|
||||||
min_duration=0,
|
min_duration=0,
|
||||||
|
@ -167,7 +168,7 @@ def process(
|
||||||
# to-do: support just querying for list of similar to cram into JSON metadata
|
# to-do: support just querying for list of similar to cram into JSON metadata
|
||||||
if verbose:
|
if verbose:
|
||||||
for filename, embedding in features.items():
|
for filename, embedding in features.items():
|
||||||
D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
|
D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1)
|
||||||
sim = list(I[0][1:])
|
sim = list(I[0][1:])
|
||||||
print(f'{filename}: {sim}')
|
print(f'{filename}: {sim}')
|
||||||
"""
|
"""
|
||||||
|
@ -180,10 +181,18 @@ def process(
|
||||||
embeddings = torch.stack( list( features.values() ) )
|
embeddings = torch.stack( list( features.values() ) )
|
||||||
sorted_similarities = {}
|
sorted_similarities = {}
|
||||||
|
|
||||||
for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"):
|
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
|
||||||
embedding = features[filename].unsqueeze(0)
|
embedding = features[filename].unsqueeze(0)
|
||||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
|
||||||
|
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
|
||||||
|
# set current index to -inf
|
||||||
|
similarities[index] = float("-inf")
|
||||||
|
similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist()
|
||||||
|
# similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
||||||
|
|
||||||
sorted_similarities[filename] = similarities
|
sorted_similarities[filename] = similarities
|
||||||
|
|
||||||
|
# sorting is slow, don't bother
|
||||||
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
return sorted_similarities
|
return sorted_similarities
|
||||||
|
@ -196,17 +205,19 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--yaml", type=Path)
|
parser.add_argument("--yaml", type=Path)
|
||||||
parser.add_argument("--text", action="store_true")
|
parser.add_argument("--text", action="store_true")
|
||||||
|
# dropped, because this might mess with the indices to map to
|
||||||
"""
|
"""
|
||||||
parser.add_argument("--trim-duration", type=float, default=3.0)
|
parser.add_argument("--trim-duration", type=float, default=3.0)
|
||||||
parser.add_argument("--min-duration", type=float, default=0)
|
parser.add_argument("--min-duration", type=float, default=0)
|
||||||
parser.add_argument("--max-duration", type=float, default=0)
|
parser.add_argument("--max-duration", type=float, default=0)
|
||||||
"""
|
"""
|
||||||
parser.add_argument("--storage-backend", type=str, default="slop")
|
parser.add_argument("--storage-backend", type=str, default="slop")
|
||||||
|
parser.add_argument("--top-k", type=int, default=8)
|
||||||
|
|
||||||
parser.add_argument("--audio-backend", type=str, default="encodec")
|
parser.add_argument("--audio-backend", type=str, default="encodec")
|
||||||
parser.add_argument("--dtype", type=str, default="float32")
|
parser.add_argument("--dtype", type=str, default="float16")
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
parser.add_argument("--device", type=str, default="cpu") # unironically faster
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -226,6 +237,7 @@ def main():
|
||||||
speaker_path=cfg.data_dir / speaker_name,
|
speaker_path=cfg.data_dir / speaker_name,
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
|
top_k=args.top_k,
|
||||||
#trim_duration=args.trim_duration,
|
#trim_duration=args.trim_duration,
|
||||||
#min_duration=args.min_duration,
|
#min_duration=args.min_duration,
|
||||||
#max_duration=args.max_duration,
|
#max_duration=args.max_duration,
|
||||||
|
@ -252,7 +264,7 @@ def main():
|
||||||
|
|
||||||
metadata[filename]["similar"] = sim
|
metadata[filename]["similar"] = sim
|
||||||
|
|
||||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
with open(str(metadata_path), "wb") as f:
|
||||||
f.write( json.dumps( metadata ) )
|
f.write( json.dumps( metadata ) )
|
||||||
#f.write( truncate_json( json.dumps( metadata ) ) )
|
#f.write( truncate_json( json.dumps( metadata ) ) )
|
||||||
|
|
||||||
|
@ -273,6 +285,7 @@ def main():
|
||||||
speaker_path=args.input_speaker,
|
speaker_path=args.input_speaker,
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
|
top_k=args.top_k,
|
||||||
|
|
||||||
#trim_duration=args.trim_duration,
|
#trim_duration=args.trim_duration,
|
||||||
#min_duration=args.min_duration,
|
#min_duration=args.min_duration,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user