"""
|
|
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
|
|
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
|
|
"""

import os
import argparse
import torch
import torchaudio
import numpy as np
import logging

from itertools import combinations

_logger = logging.getLogger(__name__)

from tqdm.auto import tqdm
from pathlib import Path
from functools import cache

import torchaudio.functional as F
import torchaudio.transforms as T

from ..config import cfg
from ..data import _load_artifact
from ..utils import truncate_json, coerce_dtype
from ..utils.io import json_read, json_write

from .g2p import encode as phonemize
from .qnt import encode as quantize, trim, convert_audio

from ..models import download_model

from ..webui import init_tts

def load_audio( path, target_sr=None ):
	waveform, sr = torchaudio.load( path )
	# mix channels down to mono
	if waveform.shape[0] > 1:
		waveform = torch.mean(waveform, dim=0, keepdim=True)
	if target_sr is None:
		target_sr = cfg.sample_rate
	# resample
	waveform, sr = convert_audio(waveform, sr, target_sr, 1), target_sr

	return waveform, sr
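
# A usage sketch (hypothetical paths) of the contract above: the waveform comes back
# mono, shape (1, T), already resampled to the requested rate.
#   wav, sr = load_audio("./speaker/utt_0001.wav")          # sr == cfg.sample_rate
#   wav16, _ = load_audio("./speaker/utt_0001.wav", 16000)  # 16 kHz input for the WavLM scorer below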

tts = None

# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
@cache
def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large', finetune=False):
	from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
	model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large')

	if finetune:
		finetune_path = Path("./data/models/wavlm_large_finetune.pth")
		if not finetune_path.exists():
			download_model(finetune_path)

		state_dict = torch.load( finetune_path )
		state_dict = state_dict['model']
		# drop the training-only projection head so load_state_dict doesn't reject the extra key
		del state_dict['loss_calculator.projection.weight']
		model.load_state_dict( state_dict )

	model = model.to(device=device, dtype=coerce_dtype(dtype))
	model = model.eval()

	return model, None
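
# the (model, None) tuple mirrors the (model, feature_extractor) pair returned by the
# commented-out transformers variant below; the ECAPA-TDNN path needs no feature extractor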

"""
logging.getLogger('s3prl').setLevel(logging.DEBUG)
logging.getLogger('speechbrain').setLevel(logging.DEBUG)
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
model = WavLMForXVector.from_pretrained(model_name)
finetune_path = Path("./data/models/wavlm_large_finetune.pth")
if finetune_path.exists():
	state_dict = torch.load( finetune_path )
	model.load_state_dict( state_dict['model'] )
model = model.to(device=device, dtype=coerce_dtype(dtype))
model = model.eval()

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

return model, feature_extractor
"""

@torch.no_grad()
def speaker_similarity_embedding(
	audio,
	**model_kwargs,
):
	model_kwargs["finetune"] = True
	device = model_kwargs.get("device", "cuda")
	dtype = model_kwargs.get("dtype", "float16")

	model, feature_extractor = _load_sim_model(**model_kwargs)

	if isinstance(audio, (str, Path)):
		audio = load_audio(audio, 16000)

	audio, sr = audio
	embeddings = model(audio.to(device=device, dtype=coerce_dtype(dtype)))
	"""
	features = feature_extractor(audio, return_tensors="pt", sampling_rate=sr)
	features = features.input_values.squeeze(0).to(dtype=coerce_dtype(dtype), device=device)

	output = model(input_values=features)
	embeddings = output.embeddings
	embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
	"""
	return embeddings
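
# A minimal sketch of SIM-O-style scoring built on the embedding above: cosine similarity
# between the speaker embeddings of two utterances. `compute_sim_o` is our own helper name,
# not part of the original API.
@torch.no_grad()
def compute_sim_o( audio_a, audio_b, **model_kwargs ):
	a = speaker_similarity_embedding( audio_a, **model_kwargs )
	b = speaker_similarity_embedding( audio_b, **model_kwargs )
	# embeddings come back as (1, D); similarity over the feature dim yields one score
	return torch.nn.functional.cosine_similarity( a, b, dim=-1 ).item()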

def batch_similar_utterances(
	speaker_path,
	yaml,

	device="cuda",
	dtype="float16",
	amp=False,

	verbose=False,
	metadata_path=None,
	top_k=8,
	top_p=0.5,
	metadata_keys=[],

	trim_duration=0,
	min_duration=0,
	max_duration=0,

	audio_backend="encodec",
	storage_backend="slop",
	similarity_backend="resp",

	return_features=False,
):
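	# a quick note on knobs, inferred from the code below:
	#   similarity_backend: "text" embeds phonemes, "resp" embeds quantized audio,
	#                       "wavlm" uses speaker embeddings, anything else falls back to MFCCs
	#   storage_backend:    "faiss" builds and returns a FAISS index; otherwise top-k
	#                       cosine similarities are computed directly
	#   top_p:              despite the name, a raw similarity-score threshold, not nucleus sampling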
	global tts

	cfg.set_audio_backend(audio_backend)
	artifact_extension = cfg.audio_backend_extension

	cfg.inference.weight_dtype = dtype # "bfloat16"
	cfg.inference.amp = amp # False

	# easy way to load the model and handle encoding audio
	if tts is None:
		tts = init_tts( config=yaml, restart=False, device=device, dtype=dtype )

	features = { key: None for key in metadata_keys }

	mfcc = None

	simplified_metadata = True # aims to slim down the raw data stored in the JSON
	slop = True # should probably have a better name for this, but it governs whether to sum the entire sequence of embeddings into one embedding, to make life easier
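	# e.g. with `slop` enabled, a (T, 1024) sequence of per-frame embeddings collapses to a
	# single (1024,) vector via .sum(dim=0), so each utterance costs one vector and pairwise
	# similarity reduces to a single cosine per pair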

	if not speaker_path.exists():
		return

	# to-do: find decent thresholds
	"""
	if similarity_backend != "wavlm":
		top_p = float("-inf")
	"""

	# compute features (embeddings if quantized already, MFCC features if raw audio)
	dim_shape = 1024
|
|
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
|
|
extension = filename.split(".")[-1]
|
|
filename = filename.replace(f".{extension}", "")
|
|
|
|
if filename not in features:
|
|
continue
|
|
|
|
if similarity_backend == "text":
|
|
if extension not in artifact_extension:
|
|
raise Exception("!")
|
|
|
|
_, metadata = _load_artifact(f'./{speaker_path}/{filename}.{extension}', return_metadata=True)

			"""
			duration = metadata["original_length"] / metadata["sample_rate"]
			if 0 < min_duration and duration < min_duration:
				continue

			if 0 < max_duration and max_duration < duration:
				continue
			"""

			lang = metadata["language"] if "language" in metadata else "en"
			if "phonemes" in metadata:
				phn = metadata["phonemes"]
			elif "text" in metadata:
				txt = metadata["text"]
				phn = phonemize( txt, language=lang )
			else:
				# nothing to embed for this entry
				continue

			phn = phn.replace("(en)", "")
			if lang != "en":
				phn = phn.replace(f"({lang})", "")

			embedding = tts.text_embedding( phn )
		elif similarity_backend == "resp":
			# treat embeddings as features, if provided quantized audio
			if extension not in artifact_extension:
				continue

			qnt, metadata = _load_artifact(f'./{speaker_path}/{filename}.{extension}', return_metadata=True)

			"""
			duration = metadata["original_length"] / metadata["sample_rate"]

			if 0 < min_duration and duration < min_duration:
				continue

			if 0 < max_duration and max_duration < duration:
				continue
			"""

			if trim_duration > 0:
				qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )

			qnt = qnt.to(device)

			embedding = tts.audio_embedding( qnt )
		# try and extract features from the raw audio itself
		else:
			# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
			if similarity_backend == "wavlm":
				embedding = speaker_similarity_embedding( f'./{speaker_path}/{filename}.{extension}' )
			else:
				wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )

				"""
				duration = wav.shape[-1] / sr
				if 0 < min_duration and duration < min_duration:
					continue

				if 0 < max_duration and max_duration < duration:
					continue
				"""

				if mfcc is None:
					mfcc = T.MFCC(sample_rate=cfg.sample_rate)

				embedding = mfcc(wav.to(device))[0].t()

		dim_shape = embedding.shape[-1]
		if slop:
			embedding = embedding.sum(dim=0)

		features[filename] = embedding

	# rely on FAISS to handle storing embeddings and handling queries
	# will probably explode in size fast...
	if storage_backend == "faiss":
		import faiss

		# stack the embeddings first, so the index can be sized off the embedding dim
		embeddings = torch.stack( list( features.values() ) ).cpu()
		index = faiss.IndexFlatL2( embeddings.shape[-1] )
		# FAISS expects float32 numpy arrays
		index.add( embeddings.float().numpy() )

		"""
		# to-do: support just querying for list of similar to cram into JSON metadata
		if verbose:
			for filename, embedding in features.items():
				D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1)
				sim = list(I[0][1:])
				print(f'{filename}: {sim}')
		"""

		return index

	# do batch cosine similarity processing

	keys = list(features.keys())
	top_k = min( top_k, len(keys) )

	if top_k == 0:
		top_k = len(keys)

	if len(keys) == 0:
		return None

	# fill any missing keys with a null embedding to keep the order the same
	null_embedding = torch.zeros( (dim_shape,), device=tts.device, dtype=tts.dtype )
	embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )

	sorted_similarities = {}

	for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}", disable=not verbose):
		if features[filename] is None:
			continue

		embedding = features[filename].unsqueeze(0)

		similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)

		# sorting is slow, don't bother
		#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)

		# mask the current utterance out so it can't match itself
		similarities[index] = float("-inf")

		topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
		# top_p here acts as a raw similarity-score threshold, not nucleus sampling
		similarities = [ ( i, keys[i], score ) for i, score in zip( topk.indices.tolist(), topk.values.tolist() ) if score > top_p ]

		sorted_similarities[filename] = similarities

	if return_features:
		return sorted_similarities, features

	return sorted_similarities
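
# A hedged sketch (not part of the original API) of how an index written by the "faiss"
# storage backend could be queried later; `query_faiss_index` and its signature are ours.
def query_faiss_index( index_path, embedding, top_k=8 ):
	import faiss

	index = faiss.read_index( str(index_path) )
	# FAISS expects a (n, dim) float32 numpy array; search returns (distances, indices)
	query = embedding.unsqueeze(0).cpu().float().numpy()
	distances, indices = index.search( query, top_k )
	return list( zip( indices[0].tolist(), distances[0].tolist() ) )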

"""
# (Attempts to) group speakers based on top-k cosine similarities, by pooling similar utterances together
# It sort of works, but the WavLM finetuned for speaker similarity leaves some false positives without decent threshold values
"""
def sort_similarities(
	path,
	num_speakers,
	out_path=None,
	threshold=0.8,
	orphan_threshold=0.6,
):
	from sklearn.cluster import KMeans

	folders = [ "1", "2", "3", "4", "5", "6-7", "8", "9", "10", "11", "12", "14", "15" ]
	embeddings = json_read(path / "0" / "embeddings.json")

	for filename, embedding in embeddings.items():
		embeddings[filename] = np.array(embedding)

	embeddings_array = np.stack( list( embeddings.values() ) )
	kmeans = KMeans(n_clusters=num_speakers).fit(embeddings_array)

	# NOTE: grouping by KMeans label is an assumption to make this runnable; the original
	# deferred to the (commented-out) similarity-bucket pooling below and never defined `speakers`
	speakers = [ [] for _ in range(num_speakers) ]
	for filename, label in zip( embeddings.keys(), kmeans.labels_ ):
		speakers[label].append( filename )

	if not out_path:
		out_path = path.parent / "speakers.json"

	"""
	orphans = []
	speakers = []

	for filename, similarities in metadata.items():
		target = False

		# find any existing buckets
		for i, pool in enumerate(speakers):
			for (idx, name, score) in similarities:
				if score and score < threshold:
					continue
				if name in pool:
					target = i
					break

			if target != False:
				break
		# not found, create new bucket
		if target == False:
			pool = [ name for (idx, name, score) in similarities if (not score or score > threshold) ]
			if filename not in pool:
				pool.append(filename)

			# orphan, check later
			if len(pool) == 1:
				orphans += pool
			else:
				speakers.append(pool)
			continue

		# insert entries into pool
		if filename not in speakers[target]:
			speakers[target].append(filename)

		for (idx, name, score) in similarities:
			if score and score < threshold:
				continue
			if name not in speakers[target]:
				speakers[target].append(name)

	# shove orphans to best scoring pool
	for filename in orphans:
		target = False
		for (idx, name, score) in metadata[filename]:
			if score and score < orphan_threshold:
				continue
			for i, pool in enumerate(speakers):
				if name in pool:
					target = i
					break
			if target != False:
				continue

		if target == False:
			continue

		speakers[target].append(filename)
	"""

	json_write( speakers, out_path )
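
# e.g. (hypothetical layout): sort_similarities( Path("./training/metadata"), num_speakers=13 )
# expects <path>/0/embeddings.json as written by the --out-path branch of main() below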

def main():
	parser = argparse.ArgumentParser()

	parser.add_argument("--input-speaker", type=Path, default=None)
	parser.add_argument("--input-voice", type=str, default=None)
	parser.add_argument("--use-dataset", action="store_true")

	parser.add_argument("--yaml", type=Path)
	parser.add_argument("--out-path", type=Path, default=None)
	# dropped, because this might mess with the indices to map to
	"""
	parser.add_argument("--trim-duration", type=float, default=3.0)
	parser.add_argument("--min-duration", type=float, default=0)
	parser.add_argument("--max-duration", type=float, default=0)
	"""
	parser.add_argument("--top-k", type=int, default=8)
	parser.add_argument("--top-p", type=float, default=0.5)

	parser.add_argument("--storage-backend", type=str, default="slop")
	parser.add_argument("--similarity-backend", type=str, default="resp")
	parser.add_argument("--audio-backend", type=str, default="encodec")

	parser.add_argument("--dtype", type=str, default="float16")
	parser.add_argument("--amp", action="store_true")
	parser.add_argument("--device", type=str, default="cuda")

	args = parser.parse_args()

	args.skip_existing = True

	if args.use_dataset:
		cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

		def add( dir, type="training", audios=True, texts=True ):
			name = str(dir)
			name = name.replace(str(cfg.data_dir), "")
			speaker_name = name
			"""
			if "LibriTTS-R" in speaker_name:
				speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
			"""

			if args.input_voice and speaker_name != args.input_voice:
				return

			metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
			metadata = json_read( metadata_path, default={} )
			metadata_keys = list(metadata.keys()) if metadata else []

			if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
				return

			try:
				similarities = batch_similar_utterances(
					speaker_path=cfg.data_dir / speaker_name,
					yaml=args.yaml,
					top_k=args.top_k,
					top_p=args.top_p,
					#trim_duration=args.trim_duration,
					#min_duration=args.min_duration,
					#max_duration=args.max_duration,
					audio_backend=args.audio_backend,
					storage_backend=args.storage_backend,
					similarity_backend=args.similarity_backend,

					metadata_keys=metadata_keys,

					device=args.device,
					dtype=args.dtype,
					amp=args.amp,

					verbose=True,
				)
			except Exception as e:
				_logger.warning(f"Failed to process '{speaker_name}': {e}")
				similarities = None

			if not similarities:
				return

			if args.storage_backend == "faiss":
				import faiss
				faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
				return

			for filename, similar in similarities.items():
				if filename not in metadata:
					metadata[filename] = {}

				# overkill, but i'm very paranoid about mismatching indices
				metadata[filename]["similar"] = [ metadata_keys.index(s[1]) for s in similar ]

			json_write( metadata, metadata_path )
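
			# each "similar" list stores indices into this speaker's ordered key list, e.g.
			# metadata["utt_0001"]["similar"] = [3, 17, 42] (filenames here are hypothetical)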

		# training
		for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
			add( data_dir, type="training" )

		# validation
		for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
			add( data_dir, type="validation" )

		# noise
		for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
			add( data_dir, type="noise", texts=False )

	elif args.input_speaker:
		similarities, features = batch_similar_utterances(
			speaker_path=args.input_speaker,
			yaml=args.yaml,
			top_k=args.top_k,
			top_p=args.top_p,

			#trim_duration=args.trim_duration,
			#min_duration=args.min_duration,
			#max_duration=args.max_duration,

			device=args.device,
			dtype=args.dtype,
			amp=args.amp,

			audio_backend=args.audio_backend,
			storage_backend=args.storage_backend,
			similarity_backend=args.similarity_backend,

			verbose=True,
			return_features=True,
		)

		if args.out_path is not None:
			features_json = {}
			for k, v in features.items():
				features_json[k] = [ x.item() for x in v ]

			json_write( similarities, args.out_path / "similarities.json" )
			json_write( features_json, args.out_path / "embeddings.json" )
		else:
			# and print
			for filename, sim in similarities.items():
				print(f'{filename}: {sim}')
	else:
		raise Exception("Requires either --use-dataset or --input-speaker")

if __name__ == "__main__":
	main()