This commit is contained in:
mrq 2025-02-10 21:47:19 -06:00
parent d6a679ca5c
commit f788faca80
2 changed files with 27 additions and 13 deletions

View File

@ -312,7 +312,7 @@ def process(
elif language == "chinese": elif language == "chinese":
language = "zh" language = "zh"
if strict_language and language not in ["en", "ja", "fr", "de", "ko", "zh"]: if strict_languages and language not in ["en", "ja", "fr", "de", "ko", "zh"]:
language = "auto" language = "auto"
if len(metadata[filename]["segments"]) == 0 or not use_slices: if len(metadata[filename]["segments"]) == 0 or not use_slices:

View File

@ -22,6 +22,7 @@ import torchaudio.functional as F
import torchaudio.transforms as T import torchaudio.transforms as T
from ..config import cfg from ..config import cfg
from ..data import _load_artifact
from ..utils import truncate_json, coerce_dtype from ..utils import truncate_json, coerce_dtype
from ..utils.io import json_read, json_write from ..utils.io import json_read, json_write
@ -171,10 +172,10 @@ def batch_similar_utterances(
if extension not in artifact_extension: if extension not in artifact_extension:
raise Exception("!") raise Exception("!")
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] _, metadata = _load_artifact(f'./{speaker_path}/{filename}.{extension}', return_metadata=True)
duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
""" """
duration = metadata["original_length"] / metadata["sample_rate"]
if 0 < min_duration and duration < min_duration: if 0 < min_duration and duration < min_duration:
continue continue
@ -182,11 +183,11 @@ def batch_similar_utterances(
continue continue
""" """
lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en" lang = metadata["language"] if "language" in metadata["language"] else "en"
if "phonemes" in artifact["metadata"]: if "phonemes" in metadata:
phn = artifact["metadata"]["phonemes"] phn = metadata["phonemes"]
elif "text" in artifact["metadata"]: elif "text" in metadata:
txt = artifact["metadata"]["text"] txt = metadata["text"]
phn = phonemize( txt, language=lang ) phn = phonemize( txt, language=lang )
phn = phn.replace("(en)", "") phn = phn.replace("(en)", "")
@ -198,10 +199,12 @@ def batch_similar_utterances(
# treat embeddings as features, if provided quantized audio # treat embeddings as features, if provided quantized audio
if extension not in artifact_extension: if extension not in artifact_extension:
continue continue
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] qnt, metadata = _load_artifact(f'./{speaker_path}/{filename}.{extension}', return_metadata=True)
""" """
duration = metadata["original_length"] / metadata["sample_rate"]
if 0 < min_duration and duration < min_duration: if 0 < min_duration and duration < min_duration:
continue continue
@ -209,8 +212,6 @@ def batch_similar_utterances(
continue continue
""" """
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
if trim_duration > 0: if trim_duration > 0:
qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) ) qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
@ -307,10 +308,23 @@ def batch_similar_utterances(
""" """
def sort_similarities( def sort_similarities(
path, path,
num_speakers,
out_path=None, out_path=None,
threshold=0.8, threshold=0.8,
orphan_threshold=0.6, orphan_threshold=0.6,
): ):
from sklearn.cluster import KMeans
folders = [ "1", "2", "3", "4", "5", "6-7", "8", "9", "10", "11", "12", "14", "15" ]
embeddings = json_read(path / "0" / "embeddings.json")
for filename, embedding in embeddings.items():
embeddings[filename] = np.array(embedding)
embeddings_array = np.stack( list( embeddings.values() ) )
kmeans = KMeans(n_clusters=num_speakers).fit(embeddings_array)
"""
if not out_path: if not out_path:
out_path = path.parent / "speakers.json" out_path = path.parent / "speakers.json"
@ -371,7 +385,7 @@ def sort_similarities(
continue continue
speakers[target].append(filename) speakers[target].append(filename)
"""
json_write( speakers, out_path ) json_write( speakers, out_path )