From 8f41d1b3241322d755c1109bda0db582363f2c99 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 17 Sep 2024 16:26:30 -0500 Subject: [PATCH] more tweaks --- vall_e/emb/similar.py | 56 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 04fae99..d3afea0 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -40,6 +40,8 @@ def load_audio( path ): return waveform, sr +tts = None + def process( speaker_path, yaml, @@ -52,7 +54,12 @@ def process( verbose=False, metadata_path=None, + + maximum_duration=0, + #use_faiss=True, ): + global tts + cfg.set_audio_backend(audio_backend) artifact_extension = cfg.audio_backend_extension @@ -60,7 +67,8 @@ def process( cfg.inference.amp = amp # False # easy way to load the model and handle encoding audio - tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype ) + if tts is None: + tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype ) queue = [] features = {} @@ -69,6 +77,13 @@ def process( mfcc = T.MFCC(sample_rate=cfg.sample_rate) + """ + # too slow + if use_faiss: + import faiss + index = None + """ + # compute features (embeddings if quantized already, MFCC features if raw audio) for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose): extension = filename.split(".")[-1] @@ -91,20 +106,40 @@ def process( if lang != "en": phn = phn.replace(f"({metadata['language']})", "") - features[filename] = tts.text_embedding( phn ) + embedding = tts.text_embedding( phn ) else: # treat embeddings as features, if provided quantized audio if extension in artifact_extension: artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()] qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device) - qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) ) + if maximum_duration > 0: + qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) ) - features[filename] = tts.audio_embedding( qnt ) + embedding = tts.audio_embedding( qnt ) # try and extract features from the raw audio itself else: # qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device) wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' ) - features[filename] = mfcc(wav.to(device))[0].t() + embedding = mfcc(wav.to(device))[0].t() + + features[filename] = embedding + + """ + if use_faiss: + if index is None: + shape = embedding.shape + index = faiss.IndexFlatL2(shape[1]) + + index.add(embedding.cpu()) + + if verbose: + for filename, embedding in features.items(): + D, I = index.search(embedding.cpu(), k=3) + # print(f'{filename}: {I[1:]}') + + if metadata_path is not None: + index.save(metadata_path) + """ keys = list(features.keys()) key_range = range(len(keys)) @@ -126,6 +161,10 @@ def process( similarities[key] = similarity + # combinations() doesn't have swapped keys + if swapped_key not in similarities: + similarities[swapped_key] = similarity + if index_a not in sorted_similarities: sorted_similarities[index_a] = {} if index_b not in sorted_similarities[index_a]: @@ -176,11 +215,12 @@ def main(): parser.add_argument("--yaml", type=Path) parser.add_argument("--text", action="store_true") + parser.add_argument("--maximum-duration", type=float, default=3.0) parser.add_argument("--audio-backend", type=str, default="encodec") - parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--dtype", type=str, default="float16") parser.add_argument("--amp", action="store_true") - parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--device", type=str, default="cpu") # unironically faster args = parser.parse_args() @@ -199,6 +239,7 @@ def main(): metadata_path=cfg.metadata_dir / f'{speaker_name}.json', yaml=args.yaml, text=args.text, + maximum_duration=args.maximum_duration, audio_backend=args.audio_backend, device=args.device, @@ -224,6 +265,7 @@ def main(): speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, + maximum_duration=args.maximum_duration, audio_backend=args.audio_backend, device=args.device,