diff --git a/vall_e/emb/codecs/dac.py b/vall_e/emb/codecs/dac.py index 9f545e4..ff6f660 100644 --- a/vall_e/emb/codecs/dac.py +++ b/vall_e/emb/codecs/dac.py @@ -4,6 +4,8 @@ from dac import DACFile from audiotools import AudioSignal from dac.utils import load_model as __load_dac_model +from typing import Union +from pathlib import Path """ Patch decode to skip things related to the metadata (namely the waveform trimming) So far it seems the raw waveform can just be returned without any post-processing diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 49fb22f..78432bb 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -174,6 +174,7 @@ def process( stride_offset=0, slice="auto", batch_size=1, + max_duration=None, low_memory=False, @@ -326,6 +327,9 @@ def process( start = int((segment['start']-0.05) * sample_rate) end = int((segment['end']+0.5) * sample_rate) + if max_duration and (end - start) / sample_rate > max_duration: + continue + if not presliced: if start < 0: start = 0 @@ -364,6 +368,7 @@ def main(): parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") parser.add_argument("--batch-size", type=int, default=0) + parser.add_argument("--max-duration", type=int, default=0) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--dtype", type=str, default="bfloat16") @@ -394,6 +399,7 @@ def main(): stride_offset=args.stride_offset, slice=args.slice, batch_size=args.batch_size, + max_duration=args.max_duration, low_memory=args.low_memory,