diff --git a/vall_e/data.py b/vall_e/data.py
index bc687b9..41e7e73 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -874,7 +874,7 @@ class Dataset(_Dataset):
 
 		return path, text, resps
 
-	def sample_prompts(self, spkr_name, ignore, should_trim=True):
+	def sample_prompts(self, spkr_name, ignore, should_trim=True, reference=None):
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
 			return None
 
@@ -895,15 +895,11 @@
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
 
+		# to-do: if reference is not None, find the closest utterances to the reference
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
 				key = _get_hdf5_path(path)
-
-				if "audio" not in cfg.hdf5[key]:
-					_logger.warning(f'MISSING AUDIO: {key}')
-					continue
-
 				qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 			else:
 				qnt = _load_quants(path, return_metadata=False)
@@ -1015,7 +1011,7 @@
 
 		# Base TTS ( => )
 		if task == "tts":
-			proms = self.sample_prompts(spkr_name, ignore=path)
+			proms = self.sample_prompts(spkr_name, ignore=path, reference=resps)
 
 			if cfg.dataset.inject_noise_in_prom:
 				# sample random noise
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
new file mode 100644
index 0000000..01d7722
--- /dev/null
+++ b/vall_e/emb/similar.py
@@ -0,0 +1,170 @@
+"""
+# Handles computing similarity scores between all utterances of a speaker provided through --input-speaker
+# Features are the model's own embeddings for already-quantized audio, or MFCC features for raw audio; utterances are then ranked by cosine similarity
+"""
+
+import os
+import json
+import argparse
+import torch
+import torchaudio
+import numpy as np
+import logging
+
+_logger = logging.getLogger(__name__)
+
+from tqdm.auto import tqdm
+from pathlib import Path
+
+import torchaudio.functional as F
+import torchaudio.transforms as T
+
+from ..config import cfg
+
+# need to validate if this is safe to import before modifying the config
+from .g2p import encode as phonemize
+from .qnt import encode as quantize, trim, convert_audio
+
+from ..webui import init_tts
+
+def load_audio( path ):
+	waveform, sr = torchaudio.load( path )
+	# mix channels
+	if waveform.shape[0] > 1:
+		waveform = torch.mean(waveform, dim=0, keepdim=True)
+	# resample
+	waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate
+
+	return waveform, sr
+
+def process(
+	input_speaker,
+	yaml,
+
+	audio_backend="encodec",
+	output_dataset="training",
+	raise_exceptions=False,
+	stride=0,
+	stride_offset=0,
+	slice="auto",
+
+	device="cuda",
+	dtype="float16",
+	amp=False,
+
+	verbose=False,
+):
+	cfg.set_audio_backend(audio_backend)
+	audio_extension = cfg.audio_backend_extension
+
+	cfg.inference.weight_dtype = dtype # "bfloat16"
+	cfg.inference.amp = amp # False
+
+	# easy way to load the model and handle encoding audio
+	tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
+
+	queue = []
+	features = {}
+	similarities = {}
+	sorted_similarities = {}
+
+	mfcc = T.MFCC(sample_rate=cfg.sample_rate)
+
+	# compute features (embeddings if quantized already, MFCC features if raw audio)
+	for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
+		extension = filename.split(".")[-1]
+
+		# treat embeddings as features, if quantized audio is provided
+		if extension in audio_extension:
+			artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
+			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
+			qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+
+			features[filename] = tts.audio_embedding( qnt )
+		# try to extract features from the raw audio itself
+		else:
+			# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
+			wav, sr = load_audio( f'./{input_speaker}/{filename}' )
+			features[filename] = mfcc(wav.to(device))[0].t()
+
+	# calculate pairs, flattened because it makes tqdm nicer
+	for filename_a, embedding_a in features.items():
+		for filename_b, embedding_b in features.items():
+			if filename_a == filename_b:
+				continue
+
+			key = f'{filename_a}:{filename_b}'
+
+			if key in queue:
+				continue
+
+			queue.append(key)
+
+	# compute similarities for every utterance pair
+	for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
+		filename_a, filename_b = key.split(":")
+		swapped_key = f'{filename_b}:{filename_a}'
+		if swapped_key in similarities:
+			similarities[key] = similarities[swapped_key]
+			continue
+
+		shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
+		similarities[key] = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
+
+	# mirror each pairwise score onto both filenames involved
+	for key, similarity in similarities.items():
+		filename_a, filename_b = key.split(":")
+
+		if filename_a not in sorted_similarities:
+			sorted_similarities[filename_a] = {}
+		if filename_b not in sorted_similarities[filename_a]:
+			sorted_similarities[filename_a][filename_b] = similarity
+
+		if filename_b not in sorted_similarities:
+			sorted_similarities[filename_b] = {}
+		if filename_a not in sorted_similarities[filename_b]:
+			sorted_similarities[filename_b][filename_a] = similarity
+
+	# sort each file's similarity scores, most to least similar
+	for key, sorted_similarity in sorted_similarities.items():
+		sorted_similarities[key] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
+
+		most_filename, most_score = sorted_similarities[key][0]
+		least_filename, least_score = sorted_similarities[key][-1]
+
+		if verbose:
+			print( f'{key}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
+
+	# to-do: store this somewhere
+
+	return sorted_similarities
+
+def main():
+	parser = argparse.ArgumentParser()
+
+	parser.add_argument("--input-speaker", type=Path)
+	parser.add_argument("--yaml", type=Path)
+	parser.add_argument("--audio-backend", type=str, default="encodec")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--raise-exceptions", action="store_true")
+
+	args = parser.parse_args()
+
+	process(
+		input_speaker=args.input_speaker,
+		yaml=args.yaml,
+
+		audio_backend=args.audio_backend,
+		raise_exceptions=args.raise_exceptions,
+
+		device=args.device,
+		dtype=args.dtype,
+		amp=args.amp,
+
+		verbose=True,
+	)
+
+if __name__ == "__main__":
+	main()
\ No newline at end of file
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 66ef906..752b209 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -94,6 +94,7 @@ class TTS():
 			id = symmap[language]
 		return torch.tensor([ id ])
 
+	# to-do: trim before quantizing, instead of after
 	def encode_audio( self, paths, trim_length=0.0 ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
@@ -122,6 +123,29 @@
 
 		return res
 
+	@torch.inference_mode()
+	def audio_embedding( self, input, prom=False ):
+		model = None
+
+		for name, engine in self.engines.items():
+			model = engine.module
+			break
+
+		# unclear which of the two is the better approach, since proms_emb and resps_emb have different properties
+		if prom:
+			return model.proms_emb(
+				input,
+				quant_level=input.shape[-1] - 1,
+				offset=0,
+				sums=True,
+			)
+		return sum([ model.resps_emb(
+			input[:, :l+1],
+			offset = 0 if l == 0 else 1, # or maybe set to 1
+			quant_level = l,
+			sums = False
+		) for l in range( input.shape[-1] - 1 ) ])
+
 	@torch.inference_mode()
 	def inference(
 		self,
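
Two to-dos above are left open: similar.py never persists its rankings ("to-do: store this somewhere"), and Dataset.sample_prompts accepts the new reference argument without using it yet. Below is a minimal sketch of how the two could be glued together, assuming the rankings returned by process() are dumped as JSON per speaker; the similarities.json filename and both helper functions are hypothetical and not part of the codebase.

import json
import random
from pathlib import Path

def save_similarities( speaker_dir, sorted_similarities ):
	# would address the "to-do: store this somewhere" in similar.py
	( Path(speaker_dir) / "similarities.json" ).write_text( json.dumps( sorted_similarities ) )

def sample_closest_prompt( speaker_dir, reference_name, top_k=4 ):
	# would address the "to-do" in sample_prompts: bias prompt selection toward the
	# utterances ranked most similar to the reference, falling back to uniform sampling
	metadata_path = Path(speaker_dir) / "similarities.json"
	candidates = []
	if metadata_path.exists():
		rankings = json.loads( metadata_path.read_text() )
		# entries are already sorted most-to-least similar: [ [filename, score], ... ]
		candidates = [ filename for filename, _ in rankings.get( reference_name, [] )[:top_k] ]
	if not candidates:
		candidates = [ p.name for p in Path(speaker_dir).iterdir() if p.name != reference_name ]
	return Path(speaker_dir) / random.choice( candidates )

Whether the features come from audio_embedding() or MFCCs only changes how the scores are computed; the sampling side only needs the per-file ranking.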