solved my problem

mrq 2024-09-17 21:58:44 -05:00
parent 8f41d1b324
commit be22b65300

@@ -24,7 +24,6 @@ import torchaudio.transforms as T
 from ..config import cfg
 from ..utils import truncate_json
-# need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, trim, convert_audio
@@ -55,8 +54,11 @@ def process(
 	verbose=False,
 	metadata_path=None,
-	maximum_duration=0,
-	#use_faiss=True,
+	trim_duration=0,
+	min_duration=0,
+	max_duration=0,
+	storage_backend="local"
 ):
 	global tts
@ -75,14 +77,15 @@ def process(
similarities = {} similarities = {}
sorted_similarities = {} sorted_similarities = {}
mfcc = T.MFCC(sample_rate=cfg.sample_rate) mfcc = None
""" slop = False # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
# too slow if storage_backend == "faiss":
if use_faiss: slop = True
import faiss elif storage_backend == "chunkdot":
index = None slop = True
""" elif storage_backend == "slop":
slop = True
# compute features (embeddings if quantized already, MFCC features if raw audio) # compute features (embeddings if quantized already, MFCC features if raw audio)
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose): for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
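Editor's note: a quick sketch of what the `slop` flag buys. Summing a variable-length [frames, dim] embedding sequence down to one fixed-size vector means every utterance becomes a single row, so all of them can be stacked into a matrix and compared (or fed to an index) in one shot. The shapes and names below are illustrative, not from the repo:

import torch
import torch.nn.functional as F

# three utterances of different lengths, same feature width
seqs = [torch.randn(n, 64) for n in (120, 80, 150)]   # [frames, dim] each

# the "slop" reduction: one fixed-size vector per utterance
vecs = torch.stack([s.sum(dim=0) for s in seqs])      # [N, dim]

# one query scored against every row at once, as the batch path later in this diff does
sims = F.cosine_similarity(vecs[0].unsqueeze(0), vecs, dim=1)
print(sims.tolist())  # sims[0] is the self-similarity, 1.0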
@@ -94,6 +97,13 @@ def process(
 			raise Exception("!")
 		artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+		duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+		if 0 < min_duration and duration < min_duration:
+			continue
+		if 0 < max_duration and max_duration < duration:
+			continue
 		lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en"
 		if "phonemes" in artifact["metadata"]:
@@ -111,42 +121,116 @@ def process(
 		# treat embeddings as features, if provided quantized audio
 		if extension in artifact_extension:
 			artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+			duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+			if 0 < min_duration and duration < min_duration:
+				continue
+			if 0 < max_duration and max_duration < duration:
+				continue
 			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-			if maximum_duration > 0:
-				qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) )
+			if trim_duration > 0:
+				qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
 			embedding = tts.audio_embedding( qnt )
 		# try and extract features from the raw audio itself
 		else:
 			# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
 			wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
+			duration = wav.shape[-1] / sr
+			if 0 < min_duration and duration < min_duration:
+				continue
+			if 0 < max_duration and max_duration < duration:
+				continue
+			if mfcc is None:
+				mfcc = T.MFCC(sample_rate=cfg.sample_rate)
 			embedding = mfcc(wav.to(device))[0].t()
+		if slop:
+			embedding = embedding.sum(dim=0)
 		features[filename] = embedding
-	"""
-	if use_faiss:
-		if index is None:
-			shape = embedding.shape
-			index = faiss.IndexFlatL2(shape[1])
-		index.add(embedding.cpu())
-
-	if verbose:
-		for filename, embedding in features.items():
-			D, I = index.search(embedding.cpu(), k=3)
-			# print(f'{filename}: {I[1:]}')
-
-	if metadata_path is not None:
-		index.save(metadata_path)
-	"""
-	keys = list(features.keys())
-	key_range = range(len(keys))
-	# queue = [ (index_a, index_b) for index_b in key_range for index_a in key_range if index_a != index_b ]
-	queue = list(combinations(key_range, 2))
-	# compute similarities for every utternace
+	# rely on FAISS to handle storing embeddings and handling queries
+	# will probably explode in size fast...........
+	if storage_backend == "faiss":
+		import faiss
+		embeddings = torch.stack( list( features.values() ) ).cpu()
+		index = faiss.IndexFlatL2( embeddings.shape[-1] )
+		index.add( embeddings )
+		"""
+		# to-do: support just querying for list of similar to cram into JSON metadata
+		if verbose:
+			for filename, embedding in features.items():
+				D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
+				sim = list(I[0][1:])
+				print(f'{filename}: {sim}')
+		"""
+		if metadata_path is not None:
+			faiss.write_index(index, str(metadata_path.with_suffix(".faiss")))
+		return
+
+	"""
+	# to-do: actually refine this, maybe
+	# desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............
+	if storage_backend == "chunkdot":
+		from chunkdot import cosine_similarity_top_k
+		embeddings = torch.stack( list( features.values() ) ).cpu().numpy()
+		similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose)
+		print(similarities)
+		return
+	"""
+
+	metadata = None
+	if metadata_path is not None:
+		metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else None
+
+	keys = list(features.keys())
+
+	# do batch cosine similarity processing
+	if slop:
+		embeddings = torch.stack( list( features.values() ) )
+		sorted_similarities = {}
+		for index, filename in enumerate(keys):
+			embedding = features[filename].unsqueeze(0)
+			similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
+			similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
+			sorted_similarities[filename] = similarities
+			most_filename, most_score = similarities[0]
+			least_filename, least_score = similarities[-1]
+			if metadata is not None:
+				if filename not in metadata:
+					metadata[filename] = {}
+				metadata[filename]["similar"] = similarities
+			if verbose:
+				print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
+		if metadata is not None:
+			with open(str(metadata_path), "w", encoding="utf-8") as f:
+				f.write( truncate_json( json.dumps( metadata ) ) )
+		return sorted_similarities
+
+	# an EXTREMELY naive implementation, fucking disgusting
+	queue = list(combinations(range(len(keys)), 2))
 	for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
 		index_a, index_b = key
 		filename_a, filename_b = keys[index_a], keys[index_b]
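Editor's note: the added lines originally constructed the `IndexFlatL2` from `embeddings.shape[-1]` before `embeddings` was assigned; the order is swapped above so the block runs. For reference, the FAISS calls in isolation: `IndexFlatL2` is an exact (brute-force) L2 index, `add` stores float32 row vectors, `search` returns distances and indices of the k nearest rows, and `write_index` persists the index. A minimal sketch with made-up data:

import faiss
import numpy as np

xb = np.random.rand(100, 64).astype(np.float32)  # 100 utterance embeddings, dim 64
index = faiss.IndexFlatL2(xb.shape[-1])          # exact L2 search, no training step
index.add(xb)                                    # store all vectors

D, I = index.search(xb[:1], k=2)                 # query with the first vector
# I[0][0] is the query itself (distance 0); I[0][1] is its nearest neighbor,
# which is why the commented-out query loop above slices off the first hit
print(I[0][1], D[0][1])

faiss.write_index(index, "speaker.faiss")        # same persistence call as the diff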
@@ -156,8 +240,17 @@ def process(
 			similarities[key] = similarities[swapped_key]
 			continue
-		shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
-		similarity = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
+		if slop:
+			embedding_a = features[filename_a]
+			embedding_b = features[filename_b]
+			similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item()
+		else:
+			shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
+			embedding_a = features[filename_a][:shortest, :]
+			embedding_b = features[filename_b][:shortest, :]
+			similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item()
 		similarities[key] = similarity
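Editor's note: the two branches above differ in what they compare. The `slop` path scores one [dim] vector per file with a single cosine similarity; the fallback truncates both [frames, dim] sequences to the shorter length and averages a per-frame cosine similarity. A small illustration with dummy tensors:

import torch
import torch.nn.functional as F

a = torch.randn(100, 64)  # [frames, dim] feature sequences
b = torch.randn(70, 64)

# sequence path: align lengths, score frame-by-frame, then average
shortest = min(a.shape[0], b.shape[0])
seq_score = F.cosine_similarity(a[:shortest, :], b[:shortest, :], dim=1).mean().item()

# "slop" path: collapse each sequence first, score once
vec_score = F.cosine_similarity(a.sum(dim=0), b.sum(dim=0), dim=0).item()

print(f"per-frame mean: {seq_score:.3f}, summed-vector: {vec_score:.3f}")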
@@ -175,13 +268,6 @@ def process(
 			if index_a not in sorted_similarities[index_b]:
 				sorted_similarities[index_b][index_a] = similarity
-	metadata = None
-	if metadata_path is not None:
-		if metadata_path.exists():
-			metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
-		else:
-			metadata = {}
 	# sort similarities scores
 	for key, sorted_similarity in sorted_similarities.items():
 		sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
@@ -201,9 +287,7 @@ def process(
 	if metadata is not None:
 		with open(str(metadata_path), "w", encoding="utf-8") as f:
-			serialized = json.dumps( metadata )
-			serialized = truncate_json( serialized )
-			f.write( serialized )
+			f.write( truncate_json( json.dumps( metadata ) ) )
 	return sorted_similarities
@@ -215,10 +299,13 @@ def main():
 	parser.add_argument("--yaml", type=Path)
 	parser.add_argument("--text", action="store_true")
-	parser.add_argument("--maximum-duration", type=float, default=3.0)
+	parser.add_argument("--trim-duration", type=float, default=3.0)
+	parser.add_argument("--min-duration", type=float, default=0)
+	parser.add_argument("--max-duration", type=float, default=0)
+	parser.add_argument("--storage-backend", type=str, default="slop")
 	parser.add_argument("--audio-backend", type=str, default="encodec")
-	parser.add_argument("--dtype", type=str, default="float16")
+	parser.add_argument("--dtype", type=str, default="float32")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--device", type=str, default="cpu") # unironically faster
@@ -236,10 +323,13 @@ def main():
 			process(
 				speaker_path=cfg.data_dir / speaker_name,
-				metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
+				metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss',
 				yaml=args.yaml,
 				text=args.text,
-				maximum_duration=args.maximum_duration,
+				trim_duration=args.trim_duration,
+				min_duration=args.min_duration,
+				max_duration=args.max_duration,
+				storage_backend=args.storage_backend,
 				audio_backend=args.audio_backend,
 				device=args.device,
@@ -260,18 +350,22 @@ def main():
 		# noise
 		for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
 			add( data_dir, type="noise", texts=False )
 	elif args.input_speaker:
 		process(
 			speaker_path=args.input_speaker,
 			yaml=args.yaml,
 			text=args.text,
-			maximum_duration=args.maximum_duration,
+			trim_duration=args.trim_duration,
+			min_duration=args.min_duration,
+			max_duration=args.max_duration,
 			audio_backend=args.audio_backend,
 			device=args.device,
 			dtype=args.dtype,
 			amp=args.amp,
+			storage_backend=args.storage_backend,
 			verbose=True,
 		)
 	else: