solved my problem
parent 8f41d1b324
commit be22b65300
@@ -24,7 +24,6 @@ import torchaudio.transforms as T
 from ..config import cfg
 from ..utils import truncate_json
 
-# need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, trim, convert_audio
 
@@ -55,8 +54,11 @@ def process(
     verbose=False,
     metadata_path=None,
 
-    maximum_duration=0,
-    #use_faiss=True,
+    trim_duration=0,
+    min_duration=0,
+    max_duration=0,
+
+    storage_backend="local"
 ):
     global tts
 
@@ -75,14 +77,15 @@ def process(
     similarities = {}
     sorted_similarities = {}
 
-    mfcc = T.MFCC(sample_rate=cfg.sample_rate)
+    mfcc = None
 
-    """
-    # too slow
-    if use_faiss:
-        import faiss
-        index = None
-    """
+    slop = False # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
+    if storage_backend == "faiss":
+        slop = True
+    elif storage_backend == "chunkdot":
+        slop = True
+    elif storage_backend == "slop":
+        slop = True
 
     # compute features (embeddings if quantized already, MFCC features if raw audio)
     for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
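The `slop` flag trades accuracy for speed: instead of comparing full [frames, features] sequences, each utterance is collapsed into a single vector, so any two utterances compare with one cosine similarity regardless of length. A minimal sketch of that idea, with illustrative shapes and names (not taken from this file):

import torch

seq_a = torch.randn(120, 80)  # 120 frames of 80-dim features
seq_b = torch.randn(300, 80)  # a different length is fine once summed

vec_a = seq_a.sum(dim=0)  # -> [80]
vec_b = seq_b.sum(dim=0)  # -> [80]

similarity = torch.nn.functional.cosine_similarity(vec_a, vec_b, dim=0).item()
print(f"cosine similarity: {similarity:.3f}")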
@@ -94,6 +97,13 @@ def process(
             raise Exception("!")
 
         artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+        duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+
+        if 0 < min_duration and duration < min_duration:
+            continue
+
+        if 0 < max_duration and max_duration < duration:
+            continue
 
         lang = artifact["metadata"]["language"] if "language" in artifact["metadata"] else "en"
         if "phonemes" in artifact["metadata"]:
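The duration gate added here derives length from the artifact's metadata (sample count over sample rate) and skips files outside the configured bounds; a zero bound disables that check. A small worked example with hypothetical metadata values:

metadata = {"original_length": 72000, "sample_rate": 24000}  # hypothetical artifact metadata
duration = metadata["original_length"] / metadata["sample_rate"]  # 72000 / 24000 = 3.0 seconds

min_duration, max_duration = 1.0, 8.0
skip = (0 < min_duration and duration < min_duration) or (0 < max_duration and max_duration < duration)
print(duration, skip)  # 3.0 False -> the file is kept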
@@ -111,42 +121,116 @@ def process(
         # treat embeddings as features, if provided quantized audio
         if extension in artifact_extension:
             artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+            duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+
+            if 0 < min_duration and duration < min_duration:
+                continue
+
+            if 0 < max_duration and max_duration < duration:
+                continue
 
             qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-            if maximum_duration > 0:
-                qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) )
+            if trim_duration > 0:
+                qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
 
             embedding = tts.audio_embedding( qnt )
         # try and extract features from the raw audio itself
         else:
             # qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
             wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
+
+            duration = wav.shape[-1] / sr
+
+            if 0 < min_duration and duration < min_duration:
+                continue
+
+            if 0 < max_duration and max_duration < duration:
+                continue
+
+            if mfcc is None:
+                mfcc = T.MFCC(sample_rate=cfg.sample_rate)
 
             embedding = mfcc(wav.to(device))[0].t()
+
+        if slop:
+            embedding = embedding.sum(dim=0)
+
         features[filename] = embedding
 
+    # rely on FAISS to handle storing embeddings and handling queries
+    # will probably explode in size fast...........
+    if storage_backend == "faiss":
+        import faiss
+
+        embeddings = torch.stack( list( features.values() ) ).cpu()
+        index = faiss.IndexFlatL2( embeddings.shape[-1] )
+        index.add( embeddings )
+
         """
-        if use_faiss:
-            if index is None:
-                shape = embedding.shape
-                index = faiss.IndexFlatL2(shape[1])
-
-            index.add(embedding.cpu())
-
+        # to-do: support just querying for list of similar to cram into JSON metadata
         if verbose:
             for filename, embedding in features.items():
-                D, I = index.search(embedding.cpu(), k=3)
-                # print(f'{filename}: {I[1:]}')
-
-        if metadata_path is not None:
-            index.save(metadata_path)
+                D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
+                sim = list(I[0][1:])
+                print(f'{filename}: {sim}')
         """
 
-    keys = list(features.keys())
-    key_range = range(len(keys))
-    # queue = [ (index_a, index_b) for index_b in key_range for index_a in key_range if index_a != index_b ]
-    queue = list(combinations(key_range, 2))
-
-    # compute similarities for every utternace
+        if metadata_path is not None:
+            faiss.write_index(index, str(metadata_path.with_suffix(".faiss")))
+
+        return
+
+    """
+    # to-do: actually refine this, maybe
+    # desu it's not super easy to install with python3.12, and it is slower than FAISS in testing............
+    if storage_backend == "chunkdot":
+        from chunkdot import cosine_similarity_top_k
+
+        embeddings = torch.stack( list( features.values() ) ).cpu().numpy()
+        similarities = cosine_similarity_top_k(embeddings, top_k=8, show_progress=verbose)
+
+        print(similarities)
+        return
+    """
+
+    metadata = None
+    if metadata_path is not None:
+        metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else None
+
+    keys = list(features.keys())
+
+    # do batch cosine similarity processing
+    if slop:
+        embeddings = torch.stack( list( features.values() ) )
+        sorted_similarities = {}
+
+        for index, filename in enumerate(keys):
+            embedding = features[filename].unsqueeze(0)
+
+            similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
+            similarities = sorted([ ( keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
+
+            sorted_similarities[filename] = similarities
+
+            most_filename, most_score = similarities[0]
+            least_filename, least_score = similarities[-1]
+
+            if metadata is not None:
+                if filename not in metadata:
+                    metadata[filename] = {}
+                metadata[filename]["similar"] = similarities
+
+            if verbose:
+                print( f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
+
+        if metadata is not None:
+            with open(str(metadata_path), "w", encoding="utf-8") as f:
+                f.write( truncate_json( json.dumps( metadata ) ) )
+
+        return sorted_similarities
+
+    # an EXTREMELY naive implementation, fucking disgusting
+    queue = list(combinations(range(len(keys)), 2))
     for key in tqdm(queue, desc="Computing similarities", disable=not verbose):
         index_a, index_b = key
         filename_a, filename_b = keys[index_a], keys[index_b]
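For reference, the FAISS branch amounts to exact brute-force nearest-neighbour search over the stacked per-utterance vectors. A self-contained sketch of the same calls, with illustrative dimensions and filename; note a query vector matches itself first, which is why the commented-out block slices `I[0][1:]`:

import faiss
import numpy as np

dim = 80                                                   # illustrative embedding width
embeddings = np.random.rand(100, dim).astype(np.float32)   # one row per utterance

index = faiss.IndexFlatL2(dim)   # exact L2 search, no training step needed
index.add(embeddings)            # store all vectors

D, I = index.search(embeddings[:1], k=2)   # query with the first utterance
print(I[0][1:])                            # its nearest neighbour, excluding itself

faiss.write_index(index, "speaker.faiss")  # persist, as the new code does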
@@ -156,8 +240,17 @@ def process(
             similarities[key] = similarities[swapped_key]
             continue
 
-        shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
-        similarity = torch.nn.functional.cosine_similarity(features[filename_a][:shortest, :], features[filename_b][:shortest, :], dim=1).mean().item()
+        if slop:
+            embedding_a = features[filename_a]
+            embedding_b = features[filename_b]
+
+            similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=0).mean().item()
+        else:
+            shortest = min( features[filename_a].shape[0], features[filename_b].shape[0] )
+            embedding_a = features[filename_a][:shortest, :]
+            embedding_b = features[filename_b][:shortest, :]
+
+            similarity = torch.nn.functional.cosine_similarity(embedding_a, embedding_b, dim=1).mean().item()
 
         similarities[key] = similarity
 
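The two branches above compare a pair either as single summed vectors (`dim=0`, the slop path) or frame-by-frame after truncating both sequences to the shorter one (`dim=1`). A sketch of the non-slop arithmetic with illustrative tensors:

import torch

feats_a = torch.randn(240, 80)  # [frames, features]; lengths differ per file
feats_b = torch.randn(180, 80)

shortest = min(feats_a.shape[0], feats_b.shape[0])
per_frame = torch.nn.functional.cosine_similarity(
    feats_a[:shortest, :], feats_b[:shortest, :], dim=1
)                                     # one similarity per aligned frame
similarity = per_frame.mean().item()  # collapse to a single pair score
print(f"{similarity:.3f}")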
@@ -175,13 +268,6 @@ def process(
             if index_a not in sorted_similarities[index_b]:
                 sorted_similarities[index_b][index_a] = similarity
 
-    metadata = None
-    if metadata_path is not None:
-        if metadata_path.exists():
-            metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
-        else:
-            metadata = {}
-
     # sort similarities scores
     for key, sorted_similarity in sorted_similarities.items():
         sorted_similarities[key] = sorted([ ( key, similarity ) for key, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
@@ -201,9 +287,7 @@ def process(
 
     if metadata is not None:
         with open(str(metadata_path), "w", encoding="utf-8") as f:
-            serialized = json.dumps( metadata )
-            serialized = truncate_json( serialized )
-            f.write( serialized )
+            f.write( truncate_json( json.dumps( metadata ) ) )
 
     return sorted_similarities
 
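`truncate_json` comes from this repo's `..utils`; judging only from its use on an already-serialized string here, it presumably trims the long float similarity scores before writing. A hypothetical stand-in with that behaviour (an assumption, not the repo's implementation):

import json
import re

def truncate_json_floats(serialized: str, places: int = 4) -> str:
    # hypothetical helper: clamp every long float literal in a JSON string to a fixed precision
    return re.sub(r"\d+\.\d{%d,}" % (places + 1), lambda m: f"{float(m.group()):.{places}f}", serialized)

print(truncate_json_floats(json.dumps({"score": 0.123456789})))  # {"score": 0.1235}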
@@ -215,10 +299,13 @@ def main():
 
     parser.add_argument("--yaml", type=Path)
     parser.add_argument("--text", action="store_true")
-    parser.add_argument("--maximum-duration", type=float, default=3.0)
+    parser.add_argument("--trim-duration", type=float, default=3.0)
+    parser.add_argument("--min-duration", type=float, default=0)
+    parser.add_argument("--max-duration", type=float, default=0)
+    parser.add_argument("--storage-backend", type=str, default="slop")
 
     parser.add_argument("--audio-backend", type=str, default="encodec")
-    parser.add_argument("--dtype", type=str, default="float16")
+    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--amp", action="store_true")
     parser.add_argument("--device", type=str, default="cpu") # unironically faster
 
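With the new flags wired up, a single-speaker run might look like the following; the module path is a placeholder inferred from the relative imports, and the directory layout is illustrative:

python3 -m <package>.emb.similar --input-speaker ./voices/speaker --yaml ./config.yaml \
    --min-duration 1 --max-duration 12 --storage-backend slop --device cpu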
@@ -236,10 +323,13 @@ def main():
 
         process(
             speaker_path=cfg.data_dir / speaker_name,
-            metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
+            metadata_path=cfg.metadata_dir / f'{speaker_name}.faiss',
             yaml=args.yaml,
             text=args.text,
-            maximum_duration=args.maximum_duration,
+            trim_duration=args.trim_duration,
+            min_duration=args.min_duration,
+            max_duration=args.max_duration,
+            storage_backend=args.storage_backend,
 
             audio_backend=args.audio_backend,
             device=args.device,
@@ -260,18 +350,22 @@ def main():
         # noise
         for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
             add( data_dir, type="noise", texts=False )
 
     elif args.input_speaker:
         process(
             speaker_path=args.input_speaker,
             yaml=args.yaml,
             text=args.text,
-            maximum_duration=args.maximum_duration,
+            trim_duration=args.trim_duration,
+            min_duration=args.min_duration,
+            max_duration=args.max_duration,
 
             audio_backend=args.audio_backend,
             device=args.device,
             dtype=args.dtype,
             amp=args.amp,
+
+            storage_backend=args.storage_backend,
             verbose=True,
         )
     else: