From 4f3c7a37c82b8fdc289b346cbf2c7f0e12c0da29 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 10 Sep 2024 16:45:59 -0500
Subject: [PATCH] also do text similarities (dont know what use I'll have for this)

---
 vall_e/emb/similar.py | 53 +++++++++++++++++++++++++++----------------
 vall_e/inference.py   | 16 +++++++++++++
 2 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 01d7722..150140c 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -40,14 +40,9 @@ def load_audio( path ):
 def process(
 	input_speaker,
 	yaml,
+	text=False,
 
 	audio_backend="encodec",
-	output_dataset="training",
-	raise_exceptions=False,
-	stride=0,
-	stride_offset=0,
-	slice="auto",
 
-	device="cuda",
 	dtype="float16",
 	amp=False,
@@ -55,7 +50,7 @@ def process(
 	verbose=False,
 ):
 	cfg.set_audio_backend(audio_backend)
-	audio_extension = cfg.audio_backend_extension
+	artifact_extension = cfg.audio_backend_extension
 
 	cfg.inference.weight_dtype = dtype # "bfloat16"
 	cfg.inference.amp = amp # False
@@ -74,18 +69,38 @@ def process(
 	for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
 		extension = filename.split(".")[-1]
 
-		# treat embeddings as features, if provided quantized audio
-		if extension in audio_extension:
+
+		if text:
+			if extension not in artifact_extension:
+				raise Exception("!")
+
 			artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
-			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-			qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+
+			lang = artifact["metadata"]["language"] if "language" in artifact["metadata"] else "en"
+			if "phonemes" in artifact["metadata"]:
+				phn = artifact["metadata"]["phonemes"]
+			elif "text" in artifact["metadata"]:
+				txt = artifact["metadata"]["text"]
+				phn = phonemize( txt, language=lang )
 
-			features[filename] = tts.audio_embedding( qnt )
-		# try and extract features from the raw audio itself
+			phn = phn.replace("(en)", "")
+			if lang != "en":
+				phn = phn.replace(f"({lang})", "")
+
+			features[filename] = tts.text_embedding( phn )
 		else:
-			# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
-			wav, sr = load_audio( f'./{input_speaker}/{filename}' )
-			features[filename] = mfcc(wav.to(device))[0].t()
+			# treat embeddings as features, if provided quantized audio
+			if extension in artifact_extension:
+				artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
+				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
+				qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+
+				features[filename] = tts.audio_embedding( qnt )
+			# try and extract features from the raw audio itself
+			else:
+				# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
+				wav, sr = load_audio( f'./{input_speaker}/{filename}' )
+				features[filename] = mfcc(wav.to(device))[0].t()
 
 	# calculate pairs, flattened because it makes tqdm nicer
 	for filename_a, embedding_a in features.items():
@@ -144,21 +159,21 @@ def main():
 	parser.add_argument("--input-speaker", type=Path)
 	parser.add_argument("--yaml", type=Path)
+	parser.add_argument("--text", action="store_true")
+
 
 	parser.add_argument("--audio-backend", type=str, default="encodec")
 	parser.add_argument("--dtype", type=str, default="bfloat16")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--device", type=str, default="cuda")
-	parser.add_argument("--raise-exceptions", action="store_true")
 
 	args = parser.parse_args()
 
 	process(
 		input_speaker=args.input_speaker,
 		yaml=args.yaml,
+		text=args.text,
 
 		audio_backend=args.audio_backend,
-		raise_exceptions=args.raise_exceptions,
 
-		device=args.device,
 		dtype=args.dtype,
 		amp=args.amp,
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 752b209..bd05ba1 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -123,6 +123,22 @@ class TTS():
 
 		return res
 
+	@torch.inference_mode()
+	def text_embedding( self, input, prom=False ):
+		model = None
+
+		for name, engine in self.engines.items():
+			model = engine.module
+			break
+
+		if isinstance( input, str ):
+			input = cfg.tokenizer.encode(input)
+
+		if isinstance( input, list ):
+			input = torch.tensor( input, dtype=torch.uint8, device=self.device )
+
+		return model.text_emb( input )
+
 	@torch.inference_mode()
 	def audio_embedding( self, input, prom=False ):
 		model = None