# vall-e/vall_e/emb/similar.py
"""
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
"""
import os
import json
import argparse
import logging

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np

from itertools import combinations
from pathlib import Path
from tqdm.auto import tqdm

from ..config import cfg
from ..utils import truncate_json

from .g2p import encode as phonemize
from .qnt import encode as quantize, trim, convert_audio

from ..webui import init_tts

_logger = logging.getLogger(__name__)

def load_audio( path ):
	waveform, sr = torchaudio.load( path )
	# mix down to mono
	if waveform.shape[0] > 1:
		waveform = torch.mean(waveform, dim=0, keepdim=True)
	# resample to the configured sample rate
	waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate
	return waveform, sr

tts = None

def process(
	speaker_path,
	yaml,

	text=False,

	audio_backend="encodec",
	device="cuda",
	dtype="float16",
	amp=False,

	verbose=False,
	metadata_path=None,

	trim_duration=0,
	min_duration=0,
	max_duration=0,

	storage_backend="local"
):
	global tts

	cfg.set_audio_backend(audio_backend)
	artifact_extension = cfg.audio_backend_extension

	cfg.inference.weight_dtype = dtype # "bfloat16"
	cfg.inference.amp = amp # False

	# easy way to load the model and handle encoding audio
	if tts is None:
		tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )

	queue = []
	features = {}
	similarities = {}
	sorted_similarities = {}

	mfcc = None

	# "slop" governs whether to just sum the entire sequence of embeddings into one embedding per utterance, to make comparisons easier
	slop = storage_backend in ("faiss", "chunkdot", "slop")
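	# with slop enabled, a per-utterance feature of shape [T, D] is collapsed into a single [D] vector,
	# so utterances of different lengths can be compared with one cosine similarity; without it,
	# sequences are later truncated to a common length and compared frame-wise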

	# compute features (embeddings if quantized already, MFCC features if raw audio)
	for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
		extension = filename.split(".")[-1]
		filename = filename.replace(f".{extension}", "")

		if text:
			if extension not in artifact_extension:
				raise Exception(f"Text mode requires already-quantized artifacts, got: {filename}.{extension}")

			artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]

			duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
			if 0 < min_duration and duration < min_duration:
				continue
			if 0 < max_duration and max_duration < duration:
				continue

			lang = artifact["metadata"]["language"] if "language" in artifact["metadata"] else "en"

			if "phonemes" in artifact["metadata"]:
				phn = artifact["metadata"]["phonemes"]
			elif "text" in artifact["metadata"]:
				txt = artifact["metadata"]["text"]
				phn = phonemize( txt, language=lang )

			# strip the language markers the phonemizer emits
			phn = phn.replace("(en)", "")
			if lang != "en":
				phn = phn.replace(f"({lang})", "")

			embedding = tts.text_embedding( phn )
		else:
			# treat embeddings as features, if provided quantized audio
			if extension in artifact_extension:
				artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]

				duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
				if 0 < min_duration and duration < min_duration:
					continue
				if 0 < max_duration and max_duration < duration:
					continue

				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)

				if trim_duration > 0:
					qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )

				embedding = tts.audio_embedding( qnt )
			# try and extract features from the raw audio itself
			else:
				# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
				wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )

				duration = wav.shape[-1] / sr
				if 0 < min_duration and duration < min_duration:
					continue
				if 0 < max_duration and max_duration < duration:
					continue

				if mfcc is None:
					mfcc = T.MFCC(sample_rate=cfg.sample_rate)

				embedding = mfcc(wav.to(device))[0].t()

		if slop:
			embedding = embedding.sum(dim=0)

		features[filename] = embedding
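
	# `features` now maps each file stem to either a summed vector (slop) or a per-frame feature sequence on `device`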

	# rely on FAISS to handle storing embeddings and handling queries
	# will probably explode in size fast...
	if storage_backend == "faiss":
		import faiss

		embeddings = torch.stack( list( features.values() ) ).cpu()
		index = faiss.IndexFlatL2( embeddings.shape[-1] )
		index.add( embeddings )

		"""
		# to-do: support just querying for a list of similar utterances to cram into the JSON metadata
		if verbose:
			for filename, embedding in features.items():
				D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
				sim = list(I[0][1:])
				print(f'{filename}: {sim}')
		"""

		if metadata_path is not None:
			faiss.write_index(index, str(metadata_path.with_suffix(".faiss")))

		return
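
	# for reference, the written index can later be loaded and queried with plain FAISS; the query below is illustrative:
	#   index = faiss.read_index(str(metadata_path.with_suffix(".faiss")))
	#   D, I = index.search(query_embedding.unsqueeze(0).cpu().numpy(), k=3)  # distances, neighbor indices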
"""
# to-do: actually refine this, maybe
# desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............
if storage_backend == "chunkdot":
from chunkdot import cosine_similarity_top_k
embeddings = torch.stack( list( features.values() ) ).cpu().numpy()
similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose)
print(similarities)
return
"""

	metadata = None
	if metadata_path is not None and metadata_path.exists():
		with open( metadata_path, "r", encoding="utf-8" ) as f:
			metadata = json.load( f )

	keys = list(features.keys())

	# do batch cosine similarity processing
	if slop:
		embeddings = torch.stack( list( features.values() ) )
		sorted_similarities = {}

		for index, filename in enumerate(keys):
			embedding = features[filename].unsqueeze(0)
			# score this utterance against every utterance at once, then drop itself and sort most-similar first
			similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
			similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)

			sorted_similarities[filename] = similarities

			most_filename, most_score = similarities[0]
			least_filename, least_score = similarities[-1]

			if metadata is not None:
				if filename not in metadata:
					metadata[filename] = {}
				metadata[filename]["similar"] = similarities

			if verbose:
				print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )

		if metadata is not None:
			with open(str(metadata_path), "w", encoding="utf-8") as f:
				f.write( truncate_json( json.dumps( metadata ) ) )

		return sorted_similarities

	# an extremely naive pairwise implementation
	queue = list(combinations(range(len(keys)), 2))

	for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
		index_a, index_b = key
		filename_a, filename_b = keys[index_a], keys[index_b]
		swapped_key = (index_b, index_a)
		if swapped_key in similarities:
			similarities[key] = similarities[swapped_key]
			continue

		if slop:
			embedding_a = features[filename_a]
			embedding_b = features[filename_b]

			similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item()
		else:
			# truncate both sequences to the shorter of the two so the frame-wise cosine similarity lines up
			shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
			embedding_a = features[filename_a][:shortest, :]
			embedding_b = features[filename_b][:shortest, :]

			similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item()

		similarities[key] = similarity

		# combinations() doesn't emit swapped keys
		if swapped_key not in similarities:
			similarities[swapped_key] = similarity

		if index_a not in sorted_similarities:
			sorted_similarities[index_a] = {}
		if index_b not in sorted_similarities[index_a]:
			sorted_similarities[index_a][index_b] = similarity

		if index_b not in sorted_similarities:
			sorted_similarities[index_b] = {}
		if index_a not in sorted_similarities[index_b]:
			sorted_similarities[index_b][index_a] = similarity

	# sort similarity scores
	for key, sorted_similarity in sorted_similarities.items():
		# map indices back to filenames and sort most-similar first, mirroring the batched path above
		sorted_similarities[key] = sorted([ ( keys[i], similarity ) for i, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)

		most_filename, most_score = sorted_similarities[key][0]
		least_filename, least_score = sorted_similarities[key][-1]

		filename = keys[key]

		if metadata is not None:
			if filename not in metadata:
				metadata[filename] = {}
			metadata[filename]["similar"] = sorted_similarities[key]

		#if verbose:
		#	print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )

	if metadata is not None:
		with open(str(metadata_path), "w", encoding="utf-8") as f:
			f.write( truncate_json( json.dumps( metadata ) ) )

	return sorted_similarities
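
# for reference, each "similar" entry written into the metadata JSON is a list of (utterance, score) pairs
# sorted most-similar first; an illustrative record might look like:
#   "utt_0001": { "similar": [ ["utt_0037", 0.912], ["utt_0004", 0.887] ] }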

def main():
	parser = argparse.ArgumentParser()

	parser.add_argument("--input-speaker", type=Path, default=None)
	parser.add_argument("--use-dataset", action="store_true")

	parser.add_argument("--yaml", type=Path)
	parser.add_argument("--text", action="store_true")

	parser.add_argument("--trim-duration", type=float, default=3.0)
	parser.add_argument("--min-duration", type=float, default=0)
	parser.add_argument("--max-duration", type=float, default=0)

	parser.add_argument("--storage-backend", type=str, default="slop")
	parser.add_argument("--audio-backend", type=str, default="encodec")

	parser.add_argument("--dtype", type=str, default="float32")
	parser.add_argument("--amp", action="store_true")
	parser.add_argument("--device", type=str, default="cpu") # unironically faster

	args = parser.parse_args()

	if args.use_dataset:
		cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

		def add( dir, type="training", audios=True, texts=True ):
			name = str(dir)
			name = name.replace(str(cfg.data_dir), "")
			speaker_name = name
			if "LibriTTS-R" in speaker_name:
				speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

			process(
				speaker_path=cfg.data_dir / speaker_name,
				metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss',

				yaml=args.yaml,
				text=args.text,

				trim_duration=args.trim_duration,
				min_duration=args.min_duration,
				max_duration=args.max_duration,

				storage_backend=args.storage_backend,
				audio_backend=args.audio_backend,

				device=args.device,
				dtype=args.dtype,
				amp=args.amp,

				verbose=True,
			)

		# training
		for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
			add( data_dir, type="training" )

		# validation
		for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
			add( data_dir, type="validation" )

		# noise
		for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
			add( data_dir, type="noise", texts=False )
	elif args.input_speaker:
		process(
			speaker_path=args.input_speaker,
			yaml=args.yaml,
			text=args.text,

			trim_duration=args.trim_duration,
			min_duration=args.min_duration,
			max_duration=args.max_duration,

			audio_backend=args.audio_backend,
			device=args.device,
			dtype=args.dtype,
			amp=args.amp,

			storage_backend=args.storage_backend,
			verbose=True,
		)
	else:
		raise Exception("Specify either --use-dataset or --input-speaker")

if __name__ == "__main__":
	main()