faster
This commit is contained in:
parent
be22b65300
commit
f00283440c
|
@ -58,7 +58,7 @@ def process(
|
|||
min_duration=0,
|
||||
max_duration=0,
|
||||
|
||||
storage_backend="local"
|
||||
storage_backend="slop"
|
||||
):
|
||||
global tts
|
||||
|
||||
|
@ -72,23 +72,15 @@ def process(
|
|||
if tts is None:
|
||||
tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
|
||||
|
||||
queue = []
|
||||
features = {}
|
||||
similarities = {}
|
||||
sorted_similarities = {}
|
||||
|
||||
mfcc = None
|
||||
|
||||
slop = False # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
|
||||
if storage_backend == "faiss":
|
||||
slop = True
|
||||
elif storage_backend == "chunkdot":
|
||||
slop = True
|
||||
elif storage_backend == "slop":
|
||||
slop = True
|
||||
simplified_metadata = True # aims to slim down the raw data in the JSON to store
|
||||
slop = True # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
|
||||
|
||||
# compute features (embeddings if quantized already, MFCC features if raw audio)
|
||||
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
|
||||
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
|
||||
extension = filename.split(".")[-1]
|
||||
filename = filename.replace(f".{extension}", "")
|
||||
|
||||
|
@ -123,11 +115,13 @@ def process(
|
|||
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
|
||||
duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
|
||||
|
||||
"""
|
||||
if 0 < min_duration and duration < min_duration:
|
||||
continue
|
||||
|
||||
if 0 < max_duration and max_duration < duration:
|
||||
continue
|
||||
"""
|
||||
|
||||
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
|
||||
|
||||
|
@ -142,11 +136,13 @@ def process(
|
|||
|
||||
duration = wav.shape[-1] / sr
|
||||
|
||||
"""
|
||||
if 0 < min_duration and duration < min_duration:
|
||||
continue
|
||||
|
||||
if 0 < max_duration and max_duration < duration:
|
||||
continue
|
||||
"""
|
||||
|
||||
if mfcc is None:
|
||||
mfcc = T.MFCC(sample_rate=cfg.sample_rate)
|
||||
|
@ -175,119 +171,20 @@ def process(
|
|||
sim = list(I[0][1:])
|
||||
print(f'{filename}: {sim}')
|
||||
"""
|
||||
|
||||
if metadata_path is not None:
|
||||
faiss.write_index(index, str(metadata_path.with_suffix(".faiss")))
|
||||
|
||||
return
|
||||
|
||||
"""
|
||||
# to-do: actually refine this, maybe
|
||||
# desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............
|
||||
if storage_backend == "chunkdot":
|
||||
from chunkdot import cosine_similarity_top_k
|
||||
|
||||
embeddings = torch.stack( list( features.values() ) ).cpu().numpy()
|
||||
similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose)
|
||||
|
||||
print(similarities)
|
||||
return
|
||||
"""
|
||||
|
||||
metadata = None
|
||||
if metadata_path is not None:
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else None
|
||||
return index
|
||||
|
||||
# do batch cosine similarity processing
|
||||
|
||||
keys = list(features.keys())
|
||||
embeddings = torch.stack( list( features.values() ) )
|
||||
sorted_similarities = {}
|
||||
|
||||
# do batch cosine similarity processing
|
||||
if slop:
|
||||
embeddings = torch.stack( list( features.values() ) )
|
||||
sorted_similarities = {}
|
||||
|
||||
for index, filename in enumerate(keys):
|
||||
embedding = features[filename].unsqueeze(0)
|
||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
||||
similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
||||
|
||||
sorted_similarities[filename] = similarities
|
||||
|
||||
most_filename, most_score = similarities[0]
|
||||
least_filename, least_score = similarities[-1]
|
||||
|
||||
if metadata is not None:
|
||||
if filename not in metadata:
|
||||
metadata[filename] = {}
|
||||
metadata[filename]["similar"] = similarities
|
||||
|
||||
if verbose:
|
||||
print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
|
||||
|
||||
if metadata is not None:
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( truncate_json( json.dumps( metadata ) ) )
|
||||
|
||||
return sorted_similarities
|
||||
|
||||
# an EXTREMELY naive implementation, fucking disgusting
|
||||
queue = list(combinations(range(len(keys)), 2))
|
||||
for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
|
||||
index_a, index_b = key
|
||||
filename_a, filename_b = keys[index_a], keys[index_b]
|
||||
swapped_key = (index_b, index_a)
|
||||
|
||||
if swapped_key in similarities:
|
||||
similarities[key] = similarities[swapped_key]
|
||||
continue
|
||||
|
||||
if slop:
|
||||
embedding_a = features[filename_a]
|
||||
embedding_b = features[filename_b]
|
||||
|
||||
similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item()
|
||||
else:
|
||||
shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
|
||||
embedding_a = features[filename_a][:shortest, :]
|
||||
embedding_b = features[filename_b][:shortest, :]
|
||||
|
||||
similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item()
|
||||
|
||||
similarities[key] = similarity
|
||||
|
||||
# combinations() doesn't have swapped keys
|
||||
if swapped_key not in similarities:
|
||||
similarities[swapped_key] = similarity
|
||||
|
||||
if index_a not in sorted_similarities:
|
||||
sorted_similarities[index_a] = {}
|
||||
if index_b not in sorted_similarities[index_a]:
|
||||
sorted_similarities[index_a][index_b] = similarity
|
||||
|
||||
if index_b not in sorted_similarities:
|
||||
sorted_similarities[index_b] = {}
|
||||
if index_a not in sorted_similarities[index_b]:
|
||||
sorted_similarities[index_b][index_a] = similarity
|
||||
|
||||
# sort similarities scores
|
||||
for key, sorted_similarity in sorted_similarities.items():
|
||||
sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
|
||||
|
||||
most_filename, most_score = sorted_similarities[key][0]
|
||||
least_filename, least_score = sorted_similarities[key][-1]
|
||||
|
||||
filename = keys[key]
|
||||
|
||||
if metadata is not None:
|
||||
if filename not in metadata:
|
||||
metadata[filename] = {}
|
||||
metadata[filename]["similar"] = sorted_similarities[key]
|
||||
|
||||
#if verbose:
|
||||
# print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
|
||||
|
||||
if metadata is not None:
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( truncate_json( json.dumps( metadata ) ) )
|
||||
for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"):
|
||||
embedding = features[filename].unsqueeze(0)
|
||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
||||
sorted_similarities[filename] = similarities
|
||||
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
||||
|
||||
return sorted_similarities
|
||||
|
||||
|
@ -299,9 +196,11 @@ def main():
|
|||
|
||||
parser.add_argument("--yaml", type=Path)
|
||||
parser.add_argument("--text", action="store_true")
|
||||
"""
|
||||
parser.add_argument("--trim-duration", type=float, default=3.0)
|
||||
parser.add_argument("--min-duration", type=float, default=0)
|
||||
parser.add_argument("--max-duration", type=float, default=0)
|
||||
"""
|
||||
parser.add_argument("--storage-backend", type=str, default="slop")
|
||||
|
||||
parser.add_argument("--audio-backend", type=str, default="encodec")
|
||||
|
@ -320,15 +219,16 @@ def main():
|
|||
speaker_name = name
|
||||
if "LibriTTS-R" in speaker_name:
|
||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
||||
|
||||
metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
|
||||
|
||||
process(
|
||||
similarities = process(
|
||||
speaker_path=cfg.data_dir / speaker_name,
|
||||
metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss',
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
trim_duration=args.trim_duration,
|
||||
min_duration=args.min_duration,
|
||||
max_duration=args.max_duration,
|
||||
#trim_duration=args.trim_duration,
|
||||
#min_duration=args.min_duration,
|
||||
#max_duration=args.max_duration,
|
||||
storage_backend=args.storage_backend,
|
||||
|
||||
audio_backend=args.audio_backend,
|
||||
|
@ -339,6 +239,23 @@ def main():
|
|||
verbose=True,
|
||||
)
|
||||
|
||||
if args.storage_backend == "faiss":
|
||||
faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
|
||||
return
|
||||
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
|
||||
metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys())
|
||||
|
||||
for filename, sim in similarities.items():
|
||||
if filename not in metadata:
|
||||
metadata[filename] = {}
|
||||
|
||||
metadata[filename]["similar"] = sim
|
||||
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
#f.write( truncate_json( json.dumps( metadata ) ) )
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
@ -356,9 +273,10 @@ def main():
|
|||
speaker_path=args.input_speaker,
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
trim_duration=args.trim_duration,
|
||||
min_duration=args.min_duration,
|
||||
max_duration=args.max_duration,
|
||||
|
||||
#trim_duration=args.trim_duration,
|
||||
#min_duration=args.min_duration,
|
||||
#max_duration=args.max_duration,
|
||||
|
||||
audio_backend=args.audio_backend,
|
||||
device=args.device,
|
||||
|
|
Loading…
Reference in New Issue
Block a user