don't do eval on stt because it's so slow and I don't even bother doing any metrics against it anyways (to-do: make this a flag)
This commit is contained in:
parent
ff7a1b4163
commit
039482a48e
|
@ -43,6 +43,7 @@ def main():
|
||||||
parser.add_argument("--demo-dir", type=Path, default=None)
|
parser.add_argument("--demo-dir", type=Path, default=None)
|
||||||
parser.add_argument("--skip-existing", action="store_true")
|
parser.add_argument("--skip-existing", action="store_true")
|
||||||
parser.add_argument("--sample-from-dataset", action="store_true")
|
parser.add_argument("--sample-from-dataset", action="store_true")
|
||||||
|
parser.add_argument("--load-from-dataloader", action="store_true")
|
||||||
parser.add_argument("--dataset-samples", type=int, default=0)
|
parser.add_argument("--dataset-samples", type=int, default=0)
|
||||||
parser.add_argument("--audio-path-root", type=str, default=None)
|
parser.add_argument("--audio-path-root", type=str, default=None)
|
||||||
parser.add_argument("--preamble", type=str, default=None)
|
parser.add_argument("--preamble", type=str, default=None)
|
||||||
|
@ -120,40 +121,40 @@ def main():
|
||||||
|
|
||||||
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
||||||
|
|
||||||
"""
|
if args.load_from_dataloader:
|
||||||
_logger.info("Loading dataloader...")
|
_logger.info("Loading dataloader...")
|
||||||
dataloader = create_train_dataloader()
|
dataloader = create_train_dataloader()
|
||||||
_logger.info("Loaded dataloader.")
|
_logger.info("Loaded dataloader.")
|
||||||
|
|
||||||
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
|
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
|
||||||
|
|
||||||
length = len( dataloader.dataset )
|
length = len( dataloader.dataset )
|
||||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||||
idx = random.randint( 0, length )
|
idx = random.randint( 0, length )
|
||||||
batch = dataloader.dataset[idx]
|
batch = dataloader.dataset[idx]
|
||||||
|
|
||||||
dir = args.demo_dir / "dataset" / f'{i}'
|
dir = args.demo_dir / "dataset" / f'{i}'
|
||||||
|
|
||||||
(dir / "out").mkdir(parents=True, exist_ok=True)
|
(dir / "out").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
metadata = batch["metadata"]
|
metadata = batch["metadata"]
|
||||||
|
|
||||||
text = metadata["text"]
|
text = metadata["text"]
|
||||||
language = metadata["language"]
|
language = metadata["language"]
|
||||||
|
|
||||||
prompt = dir / "prompt.wav"
|
prompt = dir / "prompt.wav"
|
||||||
reference = dir / "reference.wav"
|
reference = dir / "reference.wav"
|
||||||
out_path = dir / "out" / "ours.wav"
|
out_path = dir / "out" / "ours.wav"
|
||||||
|
|
||||||
if args.skip_existing and out_path.exists():
|
if args.skip_existing and out_path.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
|
open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
|
||||||
open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
|
open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
|
||||||
|
|
||||||
|
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
|
||||||
|
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
|
||||||
|
|
||||||
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
|
|
||||||
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
|
|
||||||
"""
|
|
||||||
for k, sample_dir in samples_dirs.items():
|
for k, sample_dir in samples_dirs.items():
|
||||||
if not sample_dir.exists():
|
if not sample_dir.exists():
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -115,7 +115,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
for i, task in enumerate( batch["task"] ):
|
for i, task in enumerate( batch["task"] ):
|
||||||
# easier to just change it to a tts task than drop stt tasks from the batch
|
# easier to just change it to a tts task than drop stt tasks from the batch
|
||||||
if task == "stt":
|
if task == "stt":
|
||||||
has_stt = True
|
# has_stt = True
|
||||||
batch["task"][i] = "tts"
|
batch["task"][i] = "tts"
|
||||||
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user