more tweaks
This commit is contained in:
parent
ebac1db16c
commit
84647f588a
|
@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge
|
||||||
from .emb.g2p import encode as encode_phns
|
from .emb.g2p import encode as encode_phns
|
||||||
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
||||||
from .utils.distributed import global_rank, local_rank, world_size
|
from .utils.distributed import global_rank, local_rank, world_size
|
||||||
from .utils.io import torch_save, torch_load, json_read, json_write
|
from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import cache, cached_property
|
from functools import cache, cached_property
|
||||||
|
@ -473,7 +473,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||||
|
|
||||||
if cfg.dataset.use_metadata and metadata_path.exists():
|
if cfg.dataset.use_metadata and metadata_path.exists():
|
||||||
#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
metadata = json_read( metadata_path )
|
||||||
|
|
||||||
if len(metadata) == 0:
|
if len(metadata) == 0:
|
||||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||||
|
@ -554,7 +554,8 @@ def _get_phones(path):
|
||||||
phone_path = _get_phone_path(path)
|
phone_path = _get_phone_path(path)
|
||||||
quant_path = _get_quant_path(path)
|
quant_path = _get_quant_path(path)
|
||||||
if phone_path.exists():
|
if phone_path.exists():
|
||||||
metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
|
#metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
|
||||||
|
metadata = json_read(phone_path)
|
||||||
elif quant_path.exists():
|
elif quant_path.exists():
|
||||||
_, metadata = _load_quants( path, return_metadata=True )
|
_, metadata = _load_quants( path, return_metadata=True )
|
||||||
else:
|
else:
|
||||||
|
@ -879,7 +880,7 @@ class Dataset(_Dataset):
|
||||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||||
if not metadata_path.exists():
|
if not metadata_path.exists():
|
||||||
return None
|
return None
|
||||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
metadata = json_read( metadata_path )
|
||||||
if reference not in metadata:
|
if reference not in metadata:
|
||||||
return None
|
return None
|
||||||
reference_metadata = metadata[reference]
|
reference_metadata = metadata[reference]
|
||||||
|
@ -1335,18 +1336,28 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
return train_dl, subtrain_dl, val_dl
|
return train_dl, subtrain_dl, val_dl
|
||||||
|
|
||||||
|
# parse metadata from an numpy file (.enc/.dac) and validate it
|
||||||
def process_artifact_metadata( artifact ):
|
def process_artifact_metadata( artifact ):
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
|
# text transcription (just in case)
|
||||||
if "text" in artifact["metadata"]:
|
if "text" in artifact["metadata"]:
|
||||||
metadata["text"] = artifact["metadata"]["text"]
|
metadata["text"] = artifact["metadata"]["text"]
|
||||||
|
# phonemization of text transcription (just in case)
|
||||||
if "phonemes" in artifact["metadata"]:
|
if "phonemes" in artifact["metadata"]:
|
||||||
metadata["phonemes"] = artifact["metadata"]["phonemes"]
|
metadata["phonemes"] = artifact["metadata"]["phonemes"]
|
||||||
|
# language for sampling / input creation
|
||||||
if "language" in artifact["metadata"]:
|
if "language" in artifact["metadata"]:
|
||||||
metadata["language"] = artifact["metadata"]["language"]
|
metadata["language"] = artifact["metadata"]["language"]
|
||||||
if "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]:
|
# top-k similar utterances for this utternace
|
||||||
|
if "similar" in artifact["metadata"]:
|
||||||
|
metadata["similar"] = artifact["metadata"]["similar"]
|
||||||
|
# duration for use of culling / sorting dataset
|
||||||
|
if "duration" in artifact["metadata"]:
|
||||||
|
metadata["duration"] = duration
|
||||||
|
# derive duration from sample count / sample rate
|
||||||
|
elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]:
|
||||||
metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
|
metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
|
||||||
|
|
||||||
# rephonemize if required
|
# rephonemize if required
|
||||||
if "phonemes" not in metadata and "text" in metadata:
|
if "phonemes" not in metadata and "text" in metadata:
|
||||||
metadata["phonemes"] = encode_phns( metadata["text"], language=metadata["language"] if "language" in metadata["language"] else "en" )
|
metadata["phonemes"] = encode_phns( metadata["text"], language=metadata["language"] if "language" in metadata["language"] else "en" )
|
||||||
|
@ -1361,6 +1372,16 @@ def process_artifact_metadata( artifact ):
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
# yucky, but I would like to have the LibriTTS-R utterances remapped to their LibriSpeech counterpart
|
||||||
|
# to-do: allow this to be adjusted without having to regenerate metadata / HDF5 by remapping name during dataloader creation
|
||||||
|
def remap_speaker_name( name ):
|
||||||
|
# commented out because I don't want the LibriSpeech portion of the dataset to get added
|
||||||
|
"""
|
||||||
|
if "LibriTTS-R" in speaker_name:
|
||||||
|
name = name.replace("LibriTTS-R", "LibriVox")
|
||||||
|
"""
|
||||||
|
return name
|
||||||
|
|
||||||
# parse dataset into better to sample metadata
|
# parse dataset into better to sample metadata
|
||||||
def create_dataset_metadata( skip_existing=True ):
|
def create_dataset_metadata( skip_existing=True ):
|
||||||
symmap = get_phone_symmap()
|
symmap = get_phone_symmap()
|
||||||
|
@ -1373,25 +1394,16 @@ def create_dataset_metadata( skip_existing=True ):
|
||||||
def add( dir, type="training", audios=True, texts=True ):
|
def add( dir, type="training", audios=True, texts=True ):
|
||||||
name = str(dir)
|
name = str(dir)
|
||||||
name = name.replace(root, "")
|
name = name.replace(root, "")
|
||||||
|
speaker_name = remap_speaker_name( name )
|
||||||
speaker_name = name
|
|
||||||
"""
|
|
||||||
if "LibriTTS-R" in speaker_name:
|
|
||||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||||
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
try:
|
metadata = json_read( metadata_path, default={} )
|
||||||
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
|
||||||
except Exception as e:
|
|
||||||
metadata = {}
|
|
||||||
|
|
||||||
if not os.path.isdir(f'{root}/{name}/'):
|
if not os.path.isdir(f'{root}/{name}/'):
|
||||||
return
|
return
|
||||||
|
|
||||||
# tqdm.write(f'{root}/{name}')
|
|
||||||
files = os.listdir(f'{root}/{name}/')
|
files = os.listdir(f'{root}/{name}/')
|
||||||
|
|
||||||
# grab IDs for every file
|
# grab IDs for every file
|
||||||
|
@ -1430,8 +1442,7 @@ def create_dataset_metadata( skip_existing=True ):
|
||||||
tqdm.write(f'Error while processing {id}: {e}')
|
tqdm.write(f'Error while processing {id}: {e}')
|
||||||
|
|
||||||
if wrote:
|
if wrote:
|
||||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
json_write( metadata, metadata_path )
|
||||||
f.write( json.dumps( metadata ) )
|
|
||||||
|
|
||||||
# training
|
# training
|
||||||
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||||
|
@ -1460,16 +1471,12 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
def add( dir, type="training", audios=True, texts=True ):
|
def add( dir, type="training", audios=True, texts=True ):
|
||||||
name = str(dir)
|
name = str(dir)
|
||||||
name = name.replace(root, "")
|
name = name.replace(root, "")
|
||||||
|
speaker_name = remap_speaker_name( name )
|
||||||
# yucky
|
|
||||||
speaker_name = name
|
|
||||||
if "LibriTTS-R" in speaker_name:
|
|
||||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
|
||||||
|
|
||||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||||
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
metadata = json_read(metadata_path, default={})
|
||||||
|
|
||||||
if not os.path.isdir(f'{root}/{name}/'):
|
if not os.path.isdir(f'{root}/{name}/'):
|
||||||
return
|
return
|
||||||
|
@ -1534,9 +1541,11 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
||||||
|
|
||||||
# text
|
# text
|
||||||
|
# this is a relic from when I did have the quantized audio and phoneme transcription separate
|
||||||
|
# to-do: ensure I can remove this block
|
||||||
if texts:
|
if texts:
|
||||||
if not utterance_metadata and text_exists:
|
if not utterance_metadata and text_exists:
|
||||||
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
utterance_metadata = json_read(f'{root}/{name}/{id}{_get_phone_extension()}')
|
||||||
|
|
||||||
phn = "".join(utterance_metadata["phonemes"])
|
phn = "".join(utterance_metadata["phonemes"])
|
||||||
phn = cfg.tokenizer.encode(phn)
|
phn = cfg.tokenizer.encode(phn)
|
||||||
|
@ -1552,8 +1561,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tqdm.write(f'Error while processing {id}: {e}')
|
tqdm.write(f'Error while processing {id}: {e}')
|
||||||
|
|
||||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
json_write( metadata, metadata_path )
|
||||||
f.write( json.dumps( metadata ) )
|
|
||||||
|
|
||||||
# training
|
# training
|
||||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||||
|
@ -1571,7 +1579,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
if "symmap" in hf:
|
if "symmap" in hf:
|
||||||
del hf['symmap']
|
del hf['symmap']
|
||||||
|
|
||||||
hf.create_dataset('symmap', data=json.dumps(symmap))
|
hf.create_dataset('symmap', data=json_stringify(symmap))
|
||||||
hf.close()
|
hf.close()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1596,7 +1604,7 @@ if __name__ == "__main__":
|
||||||
continue
|
continue
|
||||||
dataset.append(f'{group}/{name}')
|
dataset.append(f'{group}/{name}')
|
||||||
|
|
||||||
_logger.info(json.dumps(dataset))
|
_logger.info(json_stringify(dataset))
|
||||||
elif args.action == "metadata":
|
elif args.action == "metadata":
|
||||||
create_dataset_metadata()
|
create_dataset_metadata()
|
||||||
elif args.action == "sample":
|
elif args.action == "sample":
|
||||||
|
|
|
@ -54,6 +54,7 @@ def process(
|
||||||
verbose=False,
|
verbose=False,
|
||||||
metadata_path=None,
|
metadata_path=None,
|
||||||
top_k=8,
|
top_k=8,
|
||||||
|
metadata_keys=[],
|
||||||
|
|
||||||
trim_duration=0,
|
trim_duration=0,
|
||||||
min_duration=0,
|
min_duration=0,
|
||||||
|
@ -73,13 +74,16 @@ def process(
|
||||||
if tts is None:
|
if tts is None:
|
||||||
tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
|
tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
|
||||||
|
|
||||||
features = {}
|
features = { key: None for key in metadata_keys }
|
||||||
|
|
||||||
mfcc = None
|
mfcc = None
|
||||||
|
|
||||||
simplified_metadata = True # aims to slim down the raw data in the JSON to store
|
simplified_metadata = True # aims to slim down the raw data in the JSON to store
|
||||||
slop = True # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
|
slop = True # should probably have a better name for this, but it governs whether to just sum the entire sequence of embeddings into one embedding to make life easier
|
||||||
|
|
||||||
|
if not speaker_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
# compute features (embeddings if quantized already, MFCC features if raw audio)
|
# compute features (embeddings if quantized already, MFCC features if raw audio)
|
||||||
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
|
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
|
||||||
extension = filename.split(".")[-1]
|
extension = filename.split(".")[-1]
|
||||||
|
@ -92,11 +96,13 @@ def process(
|
||||||
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
|
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
|
||||||
duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
|
duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
|
||||||
|
|
||||||
|
"""
|
||||||
if 0 < min_duration and duration < min_duration:
|
if 0 < min_duration and duration < min_duration:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 0 < max_duration and max_duration < duration:
|
if 0 < max_duration and max_duration < duration:
|
||||||
continue
|
continue
|
||||||
|
"""
|
||||||
|
|
||||||
lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en"
|
lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en"
|
||||||
if "phonemes" in artifact["metadata"]:
|
if "phonemes" in artifact["metadata"]:
|
||||||
|
@ -178,23 +184,36 @@ def process(
|
||||||
# do batch cosine similarity processing
|
# do batch cosine similarity processing
|
||||||
|
|
||||||
keys = list(features.keys())
|
keys = list(features.keys())
|
||||||
embeddings = torch.stack( list( features.values() ) )
|
top_k = min( top_k, len(keys) )
|
||||||
|
|
||||||
|
if top_k == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype )
|
||||||
|
embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
|
||||||
sorted_similarities = {}
|
sorted_similarities = {}
|
||||||
|
|
||||||
|
|
||||||
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
|
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
|
||||||
|
if features[filename] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
embedding = features[filename].unsqueeze(0)
|
embedding = features[filename].unsqueeze(0)
|
||||||
|
|
||||||
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
|
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
|
||||||
# set current index to -inf
|
|
||||||
similarities[index] = float("-inf")
|
|
||||||
similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist()
|
|
||||||
# similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
|
|
||||||
|
|
||||||
sorted_similarities[filename] = similarities
|
|
||||||
|
|
||||||
# sorting is slow, don't bother
|
# sorting is slow, don't bother
|
||||||
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# set current index to -inf
|
||||||
|
similarities[index] = float("-inf")
|
||||||
|
|
||||||
|
topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
|
||||||
|
similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) ]
|
||||||
|
|
||||||
|
sorted_similarities[filename] = similarities
|
||||||
|
|
||||||
|
|
||||||
return sorted_similarities
|
return sorted_similarities
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -221,6 +240,8 @@ def main():
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.skip_existing = False #
|
||||||
|
|
||||||
if args.use_dataset:
|
if args.use_dataset:
|
||||||
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -228,10 +249,17 @@ def main():
|
||||||
name = str(dir)
|
name = str(dir)
|
||||||
name = name.replace(str(cfg.data_dir), "")
|
name = name.replace(str(cfg.data_dir), "")
|
||||||
speaker_name = name
|
speaker_name = name
|
||||||
|
"""
|
||||||
if "LibriTTS-R" in speaker_name:
|
if "LibriTTS-R" in speaker_name:
|
||||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
||||||
|
"""
|
||||||
|
|
||||||
metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
|
metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
|
||||||
|
metadata = json_read( metadata_path, default={} )
|
||||||
|
metadata_keys = list(metadata.keys()) if metadata else []
|
||||||
|
|
||||||
|
if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
|
||||||
|
return
|
||||||
|
|
||||||
similarities = process(
|
similarities = process(
|
||||||
speaker_path=cfg.data_dir / speaker_name,
|
speaker_path=cfg.data_dir / speaker_name,
|
||||||
|
@ -242,6 +270,7 @@ def main():
|
||||||
#min_duration=args.min_duration,
|
#min_duration=args.min_duration,
|
||||||
#max_duration=args.max_duration,
|
#max_duration=args.max_duration,
|
||||||
storage_backend=args.storage_backend,
|
storage_backend=args.storage_backend,
|
||||||
|
metadata_keys=metadata_keys,
|
||||||
|
|
||||||
audio_backend=args.audio_backend,
|
audio_backend=args.audio_backend,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
|
@ -251,28 +280,22 @@ def main():
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not similarities:
|
||||||
|
return
|
||||||
|
|
||||||
if args.storage_backend == "faiss":
|
if args.storage_backend == "faiss":
|
||||||
faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
|
faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
|
||||||
return
|
return
|
||||||
|
|
||||||
#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
|
for filename, similar in similarities.items():
|
||||||
metadata = json_read( metadata_path, default={} )
|
|
||||||
metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys())
|
|
||||||
|
|
||||||
for filename, sim in similarities.items():
|
|
||||||
if filename not in metadata:
|
if filename not in metadata:
|
||||||
metadata[filename] = {}
|
metadata[filename] = {}
|
||||||
|
|
||||||
metadata[filename]["similar"] = sim
|
# overkill but i'm very paranoid about mismatching indices
|
||||||
|
metadata[filename]["similar"] = [ metadata_keys.index(s[1]) for s in similar ]
|
||||||
|
|
||||||
json_write( metadata, metadata_path )
|
json_write( metadata, metadata_path )
|
||||||
|
|
||||||
"""
|
|
||||||
with open(str(metadata_path), "wb") as f:
|
|
||||||
f.write( json.dumps( metadata ) )
|
|
||||||
#f.write( truncate_json( json.dumps( metadata ) ) )
|
|
||||||
"""
|
|
||||||
|
|
||||||
# training
|
# training
|
||||||
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||||
add( data_dir, type="training" )
|
add( data_dir, type="training" )
|
||||||
|
@ -286,7 +309,7 @@ def main():
|
||||||
add( data_dir, type="noise", texts=False )
|
add( data_dir, type="noise", texts=False )
|
||||||
|
|
||||||
elif args.input_speaker:
|
elif args.input_speaker:
|
||||||
process(
|
similarities = process(
|
||||||
speaker_path=args.input_speaker,
|
speaker_path=args.input_speaker,
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
|
@ -304,6 +327,10 @@ def main():
|
||||||
storage_backend=args.storage_backend,
|
storage_backend=args.storage_backend,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# and print
|
||||||
|
for filename, sim in similarities.items():
|
||||||
|
print(f'{filename}: {sim}')
|
||||||
else:
|
else:
|
||||||
raise Exception("!")
|
raise Exception("!")
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,11 @@ try:
|
||||||
except:
|
except:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
def json_stringify( data ):
|
from .utils import truncate_json
|
||||||
|
|
||||||
|
def json_stringify( data, truncate=False ):
|
||||||
|
if truncate:
|
||||||
|
return truncate_json( json.dumps( data ) )
|
||||||
return json.dumps( data )
|
return json.dumps( data )
|
||||||
|
|
||||||
def json_parse( string ):
|
def json_parse( string ):
|
||||||
|
@ -26,11 +30,11 @@ def json_read( path, default=None ):
|
||||||
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
|
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
|
||||||
return json_parse( f.read() )
|
return json_parse( f.read() )
|
||||||
|
|
||||||
def json_write( data, path ):
|
def json_write( data, path, truncate=False ):
|
||||||
path = coerce_path( path )
|
path = coerce_path( path )
|
||||||
|
|
||||||
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
|
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
|
||||||
f.write( json_stringify( data ) )
|
f.write( json_stringify( data, truncate=truncate ) )
|
||||||
|
|
||||||
def coerce_path( path ):
|
def coerce_path( path ):
|
||||||
return path if isinstance( path, Path ) else Path(path)
|
return path if isinstance( path, Path ) else Path(path)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user