Don't run eval on STT, because it's so slow and I don't bother computing any metrics against it anyway (to-do: make this a flag)

This commit is contained in:
mrq 2024-09-26 18:56:57 -05:00
parent ff7a1b4163
commit 039482a48e
2 changed files with 27 additions and 26 deletions

View File

@@ -43,6 +43,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true") parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--sample-from-dataset", action="store_true") parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--load-from-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0) parser.add_argument("--dataset-samples", type=int, default=0)
parser.add_argument("--audio-path-root", type=str, default=None) parser.add_argument("--audio-path-root", type=str, default=None)
parser.add_argument("--preamble", type=str, default=None) parser.add_argument("--preamble", type=str, default=None)
@@ -120,7 +121,7 @@ def main():
samples_dirs["dataset"] = args.demo_dir / "dataset" samples_dirs["dataset"] = args.demo_dir / "dataset"
""" if args.load_from_dataloader:
_logger.info("Loading dataloader...") _logger.info("Loading dataloader...")
dataloader = create_train_dataloader() dataloader = create_train_dataloader()
_logger.info("Loaded dataloader.") _logger.info("Loaded dataloader.")
@@ -153,7 +154,7 @@ def main():
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" ) decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" ) decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
"""
for k, sample_dir in samples_dirs.items(): for k, sample_dir in samples_dirs.items():
if not sample_dir.exists(): if not sample_dir.exists():
continue continue

View File

@@ -115,7 +115,7 @@ def run_eval(engines, eval_name, dl):
for i, task in enumerate( batch["task"] ): for i, task in enumerate( batch["task"] ):
# easier to just change it to a tts task than drop stt tasks from the batch # easier to just change it to a tts task than drop stt tasks from the batch
if task == "stt": if task == "stt":
has_stt = True # has_stt = True
batch["task"][i] = "tts" batch["task"][i] = "tts"
batch["proms"][i] = batch["resps"][i][:75*3, :] batch["proms"][i] = batch["resps"][i][:75*3, :]