This commit is contained in:
mrq 2024-09-18 21:40:57 -05:00
parent fe241f6a99
commit c8d4716a9f
2 changed files with 4 additions and 3 deletions

View File

@ -195,7 +195,7 @@ def process(
sorted_similarities = {} sorted_similarities = {}
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"): for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}", disable=not verbose):
if features[filename] is None: if features[filename] is None:
continue continue
@ -241,7 +241,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
args.skip_existing = False # args.skip_existing = True #
if args.use_dataset: if args.use_dataset:
cfg.metadata_dir.mkdir(parents=True, exist_ok=True) cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
@ -278,7 +278,7 @@ def main():
dtype=args.dtype, dtype=args.dtype,
amp=args.amp, amp=args.amp,
verbose=True, verbose=False,
) )
if not similarities: if not similarities:

View File

@ -188,6 +188,7 @@ def load_engines(training=True, **model_kwargs):
keys = [ keys = [
("text_emb.weight", model.config.text_tokens ), ("text_emb.weight", model.config.text_tokens ),
("tasks_emb.weight", model.config.tasks ), ("tasks_emb.weight", model.config.tasks ),
("langs_emb.weight", model.config.langs ),
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ), ("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ), ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ), ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),