diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index d21d147..942b679 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -111,9 +111,7 @@ def speaker_similarity_embedding(
 def batch_similar_utterances(
 	speaker_path,
 	yaml,
 
-	text=False,
-	audio_backend="encodec",
 	device="cuda",
 	dtype="float16",
 	amp=False,
@@ -121,13 +119,18 @@ def batch_similar_utterances(
 	verbose=False,
 	metadata_path=None,
 	top_k=8,
+	top_p=0.5,
 
 	metadata_keys=[],
 
 	trim_duration=0,
 	min_duration=0,
 	max_duration=0,
 
-	storage_backend="slop"
+	audio_backend="encodec",
+	storage_backend="slop",
+	similarity_backend="resp",
+
+	return_features=False,
 ):
 	global tts
@@ -151,12 +154,18 @@ def batch_similar_utterances(
 	if not speaker_path.exists():
 		return
 
+	# to-do: find decent thresholds
+	"""
+	if similarity_backend != "wavlm":
+		top_p = float("-inf")
+	"""
+
 	# compute features (embeddings if quantized already, MFCC features if raw audio)
 	for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
 		extension = filename.split(".")[-1]
 		filename = filename.replace(f".{extension}", "")
 
-		if text:
+		if similarity_backend == "text":
 			if extension not in artifact_extension:
 				raise Exception("!")
@@ -183,34 +192,37 @@ def batch_similar_utterances(
 
 			phn = phn.replace(f"({metadata['language']})", "")
 			embedding = tts.text_embedding( phn )
-		else:
+		elif similarity_backend == "resp":
 			# treat embeddings as features, if provided quantized audio
-			if extension in artifact_extension:
-				artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
-				duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+			if extension not in artifact_extension:
+				continue
+			artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+			duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
 
-				"""
-				if 0 < min_duration and duration < min_duration:
-					continue
-
-				if 0 < max_duration and max_duration < duration:
-					continue
-				"""
+			"""
+			if 0 < min_duration and duration < min_duration:
+				continue
+
+			if 0 < max_duration and max_duration < duration:
+				continue
+			"""
 
-				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
+			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
 
-				if trim_duration > 0:
-					qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
-
-				embedding = tts.audio_embedding( qnt )
-			# try and extract features from the raw audio itself
+			if trim_duration > 0:
+				qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
+
+			embedding = tts.audio_embedding( qnt )
+		# try and extract features from the raw audio itself
+		else:
+			# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
+			if similarity_backend == "wavlm":
+				embedding = speaker_similarity_embedding( f'./{speaker_path}/{filename}.{extension}' )
 			else:
-				# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
 				wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
-				duration = wav.shape[-1] / sr
-				"""
+				duration = wav.shape[-1] / sr
 
 				if 0 < min_duration and duration < min_duration:
 					continue
@@ -223,6 +235,7 @@
 
 				embedding = mfcc(wav.to(device))[0].t()
 
+		dim_shape = embedding.shape[-1]
 		if slop:
 			embedding = embedding.sum(dim=0)
 
@@ -254,10 +267,10 @@ def batch_similar_utterances(
 
 	top_k = min( top_k, len(keys) )
 	if top_k == 0:
-		return
+		top_k = len(keys)
 
 	# fill any missing keys with a null embedding to keep the order the same
-	null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype )
+	null_embedding = torch.zeros( (dim_shape,), device=tts.device, dtype=tts.dtype )
 	embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
 
 	sorted_similarities = {}
@@ -277,10 +290,12 @@
 			similarities[index] = float("-inf")
 
 		topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
-		similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) ]
+		similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) if score > top_p ]
 
 		sorted_similarities[filename] = similarities
 
+	if return_features:
+		return sorted_similarities, features
 	return sorted_similarities
 
@@ -292,17 +307,20 @@ def main():
 	parser.add_argument("--use-dataset", action="store_true")
 
 	parser.add_argument("--yaml", type=Path)
-	parser.add_argument("--text", action="store_true")
+	parser.add_argument("--out-path", type=Path, default=None)
 
 	# dropped, because this might mess with the indices to map to
 	"""
 	parser.add_argument("--trim-duration", type=float, default=3.0)
 	parser.add_argument("--min-duration", type=float, default=0)
 	parser.add_argument("--max-duration", type=float, default=0)
 	"""
-	parser.add_argument("--storage-backend", type=str, default="slop")
 	parser.add_argument("--top-k", type=int, default=8)
+	parser.add_argument("--top-p", type=float, default=0.5)
+	parser.add_argument("--storage-backend", type=str, default="slop")
+	parser.add_argument("--similarity-backend", type=str, default="resp")
 	parser.add_argument("--audio-backend", type=str, default="encodec")
+
 	parser.add_argument("--dtype", type=str, default="float16")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--device", type=str, default="cuda")
@@ -336,15 +354,17 @@
 			similarities = batch_similar_utterances(
 				speaker_path=cfg.data_dir / speaker_name,
 				yaml=args.yaml,
-				text=args.text,
 				top_k=args.top_k,
+				top_p=args.top_p,
 
 				#trim_duration=args.trim_duration,
 				#min_duration=args.min_duration,
 				#max_duration=args.max_duration,
+				audio_backend=args.audio_backend,
 				storage_backend=args.storage_backend,
+				similarity_backend=args.similarity_backend,
+
 				metadata_keys=metadata_keys,
-				audio_backend=args.audio_backend,
 				device=args.device,
 				dtype=args.dtype,
 				amp=args.amp,
@@ -381,25 +401,36 @@
 		add( data_dir, type="noise", texts=False )
 
 	elif args.input_speaker:
-		similarities = batch_similar_utterances(
+		similarities, features = batch_similar_utterances(
 			speaker_path=args.input_speaker,
 			yaml=args.yaml,
-			text=args.text,
 			top_k=args.top_k,
-
+			top_p=args.top_p,
+
 			#trim_duration=args.trim_duration,
 			#min_duration=args.min_duration,
 			#max_duration=args.max_duration,
-			audio_backend=args.audio_backend,
 			device=args.device,
 			dtype=args.dtype,
 			amp=args.amp,
+			audio_backend=args.audio_backend,
 			storage_backend=args.storage_backend,
+			similarity_backend=args.similarity_backend,
+
+			verbose=True,
+			return_features=True,
 		)
 
+		features_json = {}
+		for k, v in features.items():
+			features_json[k] = [ x.item() for x in v ]
+
+		if args.out_path is not None:
+			json_write( similarities, args.out_path / "similarities.json" )
+			json_write( features_json, args.out_path / "embeddings.json" )
+
 		# and print
 		for filename, sim in similarities.items():
			print(f'{filename}: {sim}')
diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py
index 774a479..6f3bbe2 100644
--- a/vall_e/emb/transcribe.py
+++ b/vall_e/emb/transcribe.py
@@ -165,11 +165,17 @@ def transcribe(
 	start = 0
 	end = 0
 	segments = []
-	for segment in result["chunks"]:
+
+	info = torchaudio.info(audio)
+	duration = info.num_frames / info.sample_rate
+
+	for segment in result["chunks"]:
 		text = segment["text"]
 
 		if "timestamp" in segment:
 			s, e = segment["timestamp"]
+			if not e:
+				e = duration
 			start = min( start, s )
 			end = max( end, e )
 		else:
@@ -285,12 +291,12 @@
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 			continue
 
-		if group_name in ignore_groups:
+		if dataset_name in ignore_groups:
 			continue
-		if only_groups and group_name not in only_groups:
+		if only_groups and dataset_name not in only_groups:
 			continue
 
-		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/'), stride=stride, stride_offset=stride_offset), desc="Processing speaker"):
 			if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
 				continue
diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index 27a75c3..4055176 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -46,15 +46,11 @@ except Exception as e:
 	ERROR_ARCHES["bitnet"] = e
 	pass
 
-from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
-AVAILABLE_ARCHES.append("mixtral")
-"""
 try:
 	from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
 	AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
 	ERROR_ARCHES["mixtral"] = e
-"""
 
 try:
 	from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config
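
For reference, a minimal sketch (not part of the patch) of how the reworked helper might be called directly once this diff is applied. The keyword arguments mirror the new batch_similar_utterances() signature shown in the similar.py hunks above; the speaker folder and config YAML paths are placeholders.

# Hypothetical usage of the updated vall_e.emb.similar helper; paths are placeholders.
from pathlib import Path
from vall_e.emb.similar import batch_similar_utterances

similarities, features = batch_similar_utterances(
	speaker_path=Path("./voices/speaker_a"),  # folder of quantized artifacts or raw audio
	yaml=Path("./training/config.yaml"),      # model config YAML, as in main()
	top_k=8,                                  # keep the 8 closest utterances per file
	top_p=0.5,                                # drop matches whose similarity score is not above 0.5
	similarity_backend="resp",                # "resp" (audio codes), "text", or "wavlm"
	audio_backend="encodec",
	storage_backend="slop",
	device="cuda", dtype="float16", amp=False,
	verbose=True,
	return_features=True,                     # also return the raw per-file embeddings
)

for filename, matches in similarities.items():
	print(filename, matches)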