From 682e4387dc8325f90af1a639d65c3d6ed6c362bd Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 25 Jul 2024 12:39:57 -0500 Subject: [PATCH] oops (fixed proms being erased from a config oversight) --- vall_e/config.py | 5 +++-- vall_e/data.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 6bfcd52..5b64a83 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -152,7 +152,7 @@ class Dataset: random_utterance: float = 1.0 max_prompts: int = 3 - prompt_duration: float = 0.0 # legacy + prompt_duration: float | None = None # legacy max_resps: int = 1 p_resp_append: float = 1.0 @@ -839,7 +839,8 @@ class Config(BaseConfig): if self.hyperparameters.scheduler == "": self.hyperparameters.torch_scheduler = True - self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] + if self.dataset.prompt_duration is not None: + self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] if self.trainer.backend == "local" and self.distributed: self.trainer.ddp = True diff --git a/vall_e/data.py b/vall_e/data.py index 5ec4c6a..fb6d1c1 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -434,7 +434,7 @@ class Dataset(_Dataset): self.training = training self.dataset_type = "training" if self.training else "validation" self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation - self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group" + self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path" self.sampler_order = cfg.dataset.sample_order self.sampler_shuffle = cfg.dataset.sample_shuffle