diff --git a/vall_e/config.py b/vall_e/config.py index 29d9288..ef6d515 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -144,11 +144,13 @@ class Dataset: phones_range: list[int] = field(default_factory=lambda: [4, 256]) duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) + prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) min_utterances: int = 2 random_utterance: float = 1.0 max_prompts: int = 3 - prompt_duration: float = 3.0 + + prompt_duration: float = 0.0 # legacy max_resps: int = 1 p_resp_append: float = 1.0 @@ -676,6 +678,9 @@ class Config(_Config): if self.hyperparameters.scheduler == "": self.hyperparameters.torch_scheduler = True + if self.dataset.prompt_duration != 0: + self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] + # Preserves the old behavior class NaiveTokenizer: def get_vocab( self ): diff --git a/vall_e/data.py b/vall_e/data.py index fa61289..bb28d90 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -368,12 +368,8 @@ class Dataset(_Dataset): ) """ - # shuffle it up a bit prom_length = 0 - if cfg.experimental and False: - trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second)) - else: - trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second) + trim_length = random.randint(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second for _ in range(cfg.dataset.max_prompts): path = random.choice(choices) @@ -388,7 +384,7 @@ class Dataset(_Dataset): else: qnt = _load_quants(path) - if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]: + if 0 < trim_length and trim_length < qnt.shape[0]: qnt = trim( qnt, trim_length ) prom_list.append(qnt) @@ -401,7 +397,7 @@ class Dataset(_Dataset): # as you technically can't just append encodec sequences together like this without issues prom = torch.cat(prom_list) - if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]: + if 0 < trim_length and trim_length < prom.shape[0]: prom = trim( prom, trim_length ) return prom @@ -474,7 +470,7 @@ class Dataset(_Dataset): resps = torch.concat([ resps, qnt ]) task = "tts" - trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + trim_length = random.randint(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps