mrq 2025-02-09 13:02:51 -06:00
parent 953015748f
commit 075ffef68a
3 changed files with 76 additions and 43 deletions

File 1 of 3

@@ -111,9 +111,7 @@ def speaker_similarity_embedding(
 def batch_similar_utterances(
     speaker_path,
     yaml,
-    text=False,
-    audio_backend="encodec",
     device="cuda",
     dtype="float16",
     amp=False,
@@ -121,13 +119,18 @@ def batch_similar_utterances(
     verbose=False,
     metadata_path=None,
     top_k=8,
+    top_p=0.5,
     metadata_keys=[],
     trim_duration=0,
     min_duration=0,
     max_duration=0,
-    storage_backend="slop"
+    audio_backend="encodec",
+    storage_backend="slop",
+    similarity_backend="resp",
+    return_features=False,
 ):
     global tts
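As a sketch of what a caller of the reworked signature looks like (the paths and values here are placeholders for illustration, not from this commit):

    from pathlib import Path

    # hypothetical invocation; speaker/config paths are made up
    similarities, features = batch_similar_utterances(
        speaker_path=Path("./voices/speaker"),
        yaml=Path("./training/config.yaml"),
        top_k=8,
        top_p=0.5,                      # new: score floor applied after top-k
        audio_backend="encodec",
        storage_backend="slop",
        similarity_backend="resp",      # new: "text", "resp", "wavlm", or raw-audio fallback
        return_features=True,           # new: also hand back the computed embeddings
    )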
@@ -151,12 +154,18 @@ def batch_similar_utterances(
     if not speaker_path.exists():
         return
 
+    # to-do: find decent thresholds
+    """
+    if similarity_backend != "wavlm":
+        top_p = float("-inf")
+    """
+
     # compute features (embeddings if quantized already, MFCC features if raw audio)
     for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
         extension = filename.split(".")[-1]
         filename = filename.replace(f".{extension}", "")
 
-        if text:
+        if similarity_backend == "text":
             if extension not in artifact_extension:
                 raise Exception("!")
@@ -183,34 +192,37 @@ def batch_similar_utterances(
             phn = phn.replace(f"({metadata['language']})", "")
 
             embedding = tts.text_embedding( phn )
-        else:
+        elif similarity_backend == "resp":
             # treat embeddings as features, if provided quantized audio
-            if extension in artifact_extension:
-                artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
-                duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+            if extension not in artifact_extension:
+                continue
+
+            artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+            duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
 
             """
             if 0 < min_duration and duration < min_duration:
                 continue
             if 0 < max_duration and max_duration < duration:
                 continue
             """
 
             qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
 
             if trim_duration > 0:
                 qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
 
             embedding = tts.audio_embedding( qnt )
         # try and extract features from the raw audio itself
+        elif similarity_backend == "wavlm":
+            embedding = speaker_similarity_embedding( f'./{speaker_path}/{filename}.{extension}' )
         else:
             # qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
             wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
-            duration = wav.shape[-1] / sr
             """
+            duration = wav.shape[-1] / sr
             if 0 < min_duration and duration < min_duration:
                 continue
@@ -223,6 +235,7 @@ def batch_similar_utterances(
             embedding = mfcc(wav.to(device))[0].t()
 
+        dim_shape = embedding.shape[-1]
         if slop:
             embedding = embedding.sum(dim=0)
@ -254,10 +267,10 @@ def batch_similar_utterances(
top_k = min( top_k, len(keys) ) top_k = min( top_k, len(keys) )
if top_k == 0: if top_k == 0:
return top_k = len(keys)
# fill any missing keys with a null embedding to keep the order the same # fill any missing keys with a null embedding to keep the order the same
null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype ) null_embedding = torch.zeros( (dim_shape,), device=tts.device, dtype=tts.dtype )
embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] ) embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
sorted_similarities = {} sorted_similarities = {}
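The switch from a hard-coded 1024 to dim_shape matters because the feature width now varies with the similarity backend (text/resp embeddings, WavLM vectors, or MFCC bins). A minimal, self-contained sketch of why the null filler has to be sized dynamically (toy tensors, not the script's real features):

    import torch

    # toy feature dict: one real embedding, one missing entry
    features = {"utt_a": torch.randn(512), "utt_b": None}
    dim_shape = next(f.shape[-1] for f in features.values() if f is not None)

    # the null filler must match whatever width the backend produced,
    # or torch.stack() below raises a size-mismatch error
    null_embedding = torch.zeros((dim_shape,))
    embeddings = torch.stack([f if f is not None else null_embedding for f in features.values()])
    assert embeddings.shape == (2, 512)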
@ -277,10 +290,12 @@ def batch_similar_utterances(
similarities[index] = float("-inf") similarities[index] = float("-inf")
topk = torch.topk(similarities, k=top_k, largest=True, sorted=True) topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) ] similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) if score > top_p ]
sorted_similarities[filename] = similarities sorted_similarities[filename] = similarities
if return_features:
return sorted_similarities, features
return sorted_similarities return sorted_similarities
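Despite the name, top_p here is not nucleus sampling: it is a hard similarity floor applied to the already-selected top-k results. A self-contained sketch with toy scores:

    import torch

    keys = ["utt0", "utt1", "utt2", "utt3"]
    scores = torch.tensor([0.91, 0.72, 0.48, 0.13])  # toy cosine similarities
    top_k, top_p = 3, 0.5

    topk = torch.topk(scores, k=top_k, largest=True, sorted=True)
    # entries below the top_p floor are dropped even if they made the top-k cut
    kept = [(i, keys[i], s) for i, s in zip(topk.indices.tolist(), topk.values.tolist()) if s > top_p]
    print(kept)  # [(0, 'utt0', 0.91...), (1, 'utt1', 0.72...)]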
@@ -292,17 +307,20 @@ def main():
     parser.add_argument("--use-dataset", action="store_true")
 
     parser.add_argument("--yaml", type=Path)
-    parser.add_argument("--text", action="store_true")
+    parser.add_argument("--out-path", type=Path, default=None)
 
     # dropped, because this might mess with the indices to map to
     """
     parser.add_argument("--trim-duration", type=float, default=3.0)
     parser.add_argument("--min-duration", type=float, default=0)
     parser.add_argument("--max-duration", type=float, default=0)
     """
-    parser.add_argument("--storage-backend", type=str, default="slop")
 
     parser.add_argument("--top-k", type=int, default=8)
+    parser.add_argument("--top-p", type=float, default=0.5)
+    parser.add_argument("--storage-backend", type=str, default="slop")
+    parser.add_argument("--similarity-backend", type=str, default="resp")
     parser.add_argument("--audio-backend", type=str, default="encodec")
     parser.add_argument("--dtype", type=str, default="float16")
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--device", type=str, default="cuda")
@@ -336,15 +354,17 @@ def main():
             similarities = batch_similar_utterances(
                 speaker_path=cfg.data_dir / speaker_name,
                 yaml=args.yaml,
-                text=args.text,
                 top_k=args.top_k,
+                top_p=args.top_p,
                 #trim_duration=args.trim_duration,
                 #min_duration=args.min_duration,
                 #max_duration=args.max_duration,
-                audio_backend=args.audio_backend,
                 storage_backend=args.storage_backend,
+                similarity_backend=args.similarity_backend,
                 metadata_keys=metadata_keys,
+                audio_backend=args.audio_backend,
                 device=args.device,
                 dtype=args.dtype,
                 amp=args.amp,
@ -381,25 +401,36 @@ def main():
add( data_dir, type="noise", texts=False ) add( data_dir, type="noise", texts=False )
elif args.input_speaker: elif args.input_speaker:
similarities = batch_similar_utterances( similarities, features = batch_similar_utterances(
speaker_path=args.input_speaker, speaker_path=args.input_speaker,
yaml=args.yaml, yaml=args.yaml,
text=args.text,
top_k=args.top_k, top_k=args.top_k,
top_p=args.top_p,
#trim_duration=args.trim_duration, #trim_duration=args.trim_duration,
#min_duration=args.min_duration, #min_duration=args.min_duration,
#max_duration=args.max_duration, #max_duration=args.max_duration,
audio_backend=args.audio_backend,
device=args.device, device=args.device,
dtype=args.dtype, dtype=args.dtype,
amp=args.amp, amp=args.amp,
audio_backend=args.audio_backend,
storage_backend=args.storage_backend, storage_backend=args.storage_backend,
similarity_backend=args.similarity_backend,
verbose=True, verbose=True,
return_features=True,
) )
features_json = {}
for k, v in features.items():
features_json[k] = [ x.item() for x in v ]
if args.out_path is not None:
json_write( similarities, args.out_path / "similarities.json" )
json_write( features_json, args.out_path / "embeddings.json" )
# and print # and print
for filename, sim in similarities.items(): for filename, sim in similarities.items():
print(f'{filename}: {sim}') print(f'{filename}: {sim}')
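With --out-path set, the two emitted files would look roughly like this (keys and values invented for illustration):

    # similarities.json — per utterance, (index, filename, score) triples that cleared top_p
    { "utt_0001": [[3, "utt_0004", 0.87], [7, "utt_0009", 0.79]], ... }

    # embeddings.json — per utterance, the feature vector flattened to plain floats
    { "utt_0001": [0.0123, -0.4567, ...], ... }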

File 2 of 3

@@ -165,11 +165,17 @@ def transcribe(
     start = 0
     end = 0
     segments = []
 
+    info = torchaudio.info(audio)
+    duration = info.num_frames / info.sample_rate
+
     for segment in result["chunks"]:
         text = segment["text"]
 
         if "timestamp" in segment:
             s, e = segment["timestamp"]
+            if not e:
+                e = duration
             start = min( start, s )
             end = max( end, e )
         else:
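The new guard covers whisper-style pipelines that report the final chunk's end timestamp as None (an open-ended last segment); a sketch of the same fallback in isolation (the helper name is made up):

    import torchaudio

    def clamped_segments(audio_path, chunks):
        # derive the true clip length from file metadata once, then use it to
        # close any (start, None) timestamp a transcription chunk may carry
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        for chunk in chunks:
            s, e = chunk["timestamp"]
            yield s, (e if e else duration), chunk["text"]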
@@ -285,12 +291,12 @@ def transcribe_batch(
         if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
             continue
 
-        if group_name in ignore_groups:
+        if dataset_name in ignore_groups:
             continue
-        if only_groups and group_name not in only_groups:
+        if only_groups and dataset_name not in only_groups:
             continue
 
-        for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
+        for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/'), stride=stride, stride_offset=stride_offset), desc="Processing speaker"):
             if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
                 continue
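Threading stride/stride_offset into process_items suggests the speaker walk can be sharded across parallel workers; a plausible sketch of such a helper (an assumption on my part, the real implementation is not in this diff):

    def process_items(items, stride=0, stride_offset=0):
        # with stride=N and stride_offset=i, worker i takes every N-th item,
        # letting N independent processes split one list without coordination
        items = sorted(items)
        if stride <= 1:
            return items
        return [item for index, item in enumerate(items) if index % stride == stride_offset]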

File 3 of 3

@@ -46,15 +46,11 @@ except Exception as e:
     ERROR_ARCHES["bitnet"] = e
     pass
 
-from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
-AVAILABLE_ARCHES.append("mixtral")
-"""
 try:
     from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
     AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
     ERROR_ARCHES["mixtral"] = e
-"""
 
 try:
     from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config
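This brings mixtral in line with the optional-import guard used for the other architectures: an import failure is recorded instead of crashing at module load. The pattern in a generic, self-contained form (register_arch is a hypothetical helper, not from this repo):

    import importlib

    AVAILABLE_ARCHES, ERROR_ARCHES = [], {}

    def register_arch(name, module):
        # register the arch only if its (possibly heavy) dependency imports
        # cleanly; otherwise stash the exception for a later error message
        try:
            importlib.import_module(module)
            AVAILABLE_ARCHES.append(name)
        except Exception as e:
            ERROR_ARCHES[name] = e

    register_arch("mixtral", "transformers.models.mixtral")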