From 74e531d391a2d2b636141364c8942851a69dd242 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 18 May 2024 12:02:56 -0500
Subject: [PATCH] ugh

---
 scripts/process_dataset.py | 2 +-
 vall_e/emb/qnt.py          | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py
index 7f1d00a..c57b933 100644
--- a/scripts/process_dataset.py
+++ b/scripts/process_dataset.py
@@ -144,7 +144,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 			i = i + 1
 			outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
 
-			text = metadata[filename]["text"]
+			text = segment["text"]
 			if len(text) == 0:
 				continue
 
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 55cc61b..12d9d89 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -285,7 +285,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
 	wav = convert_audio(wav, sr, model.sample_rate, model.channels)
 	wav = wav.to(device)
 
-	encoded_frames = model.encode(wav)
+	with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+		encoded_frames = model.encode(wav)
 	qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
 	return qnt
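
Note: the process_dataset.py hunk simply reads the transcription from the per-segment record rather than the file-level metadata entry. The qnt.py hunk wraps the encode call in torch.autocast so encoding runs under the same mixed-precision (AMP) settings as the rest of inference. A minimal standalone sketch of that autocast pattern, assuming only a generic encoder object with an encode() method (the model, dtype, and amp arguments below are illustrative stand-ins for the repo's cfg.inference values, not its actual API):

    import torch

    def encode_with_amp(model, wav: torch.Tensor,
                        dtype: torch.dtype = torch.float16,
                        amp: bool = True) -> torch.Tensor:
        # Move audio to the GPU, then run the encoder under CUDA autocast.
        # With amp=False the context is a no-op and the call behaves like a
        # plain model.encode(wav), so the flag can mirror a config switch.
        wav = wav.to("cuda")
        with torch.autocast("cuda", dtype=dtype, enabled=amp):
            return model.encode(wav)

With enabled=False the region runs at the default precision, so a single code path serves both AMP and full-precision inference instead of branching around two copies of the call.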