diff --git a/data/config.yaml b/data/config.yaml
index 1055cbc..65ec089 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -20,6 +20,7 @@ dataset:
   prompt_duration: 3.0
 
   sample_type: speaker
+  tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
 
 models:
   _models:
@@ -102,6 +103,7 @@ trainer:
 
 inference:
   use_vocos: True
+  normalize: False # do NOT change this unless you know exactly what you are doing.
 
 bitsandbytes:
   enabled: false
diff --git a/vall_e/config.py b/vall_e/config.py
index 3ff8a66..2057007 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -132,6 +132,8 @@ class Dataset:
     prompt_duration: float = 3.0
 
     sample_type: str = "path" # path | speaker
+
+    tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
 @dataclass()
 class Model:
@@ -475,7 +477,8 @@ try:
 
     if not cfg.dataset.use_hdf5:
         cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
         cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
-        cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
+
+    cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
 except Exception as e:
     pass
diff --git a/vall_e/data.py b/vall_e/data.py
index e455723..2d54080 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -204,7 +204,7 @@ class Dataset(_Dataset):
 
     @cached_property
     def tasks(self):
-        return ["tts"] # "ns", "sr", "tse", "cse", "nse"
+        return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
 
     def _get_phone_symmap(self):
         return get_phone_symmap()
@@ -216,16 +216,22 @@ class Dataset(_Dataset):
         return get_task_symmap()
 
     def get_task_token( self, token ):
+        if not hasattr(self, "task_symmap"): # lazily cache the task symmap on first use
+            self.task_symmap = self._get_task_symmap()
         return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
 
     def sample_noise(self):
         paths = []
-        print(cfg.dataset.noise)
         for data_dir in cfg.dataset.noise:
             paths.extend(data_dir.rglob("*.qnt.pt"))
         path = random.choice(paths)
+
+        if False and cfg.dataset.use_hdf5: # "False and" keeps the HDF5 branch disabled for now; noise is always loaded from disk
+            key = f'/noise/{_get_hdf5_path(path)}'
+            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+        else:
+            qnt = _load_quants(path)
+        return qnt
 
     def sample_speakers(self, ignore=[]):
         choices = set(self.spkrs) - set(ignore)
@@ -305,7 +311,7 @@
             # extend the noise to fill the target audio
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
-            proms = merge_audio(resps, noise, scale=[1, 0.125])
+            proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu")
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
@@ -325,10 +331,10 @@
             smallest_size = min(resps.shape[0], other_proms.shape[0])
             if other_proms.shape[0] == smallest_size:
-                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
                 noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
             else:
-                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
                 noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] ) # stitch together the proms
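
Note: Dataset.tasks now returns cfg.dataset.tasks_list verbatim, so duplicate entries in the list act as crude sampling weights when a task is drawn per training sample. A minimal sketch of that behavior, assuming the dataset draws with random.choice; the sample_task helper is hypothetical, and only tasks_list and the example weighting come from this patch:

import random

# example weighting from the patch comment: "tts" appears four times,
# so it is drawn four times as often as "ns", "sr", or "tse"
tasks_list = ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]

def sample_task(tasks: list[str]) -> str:
    # uniform choice over a duplicated list == weighted task sampling
    return random.choice(tasks)

print(sample_task(tasks_list))

The repeated device="cpu" arguments to merge_audio presumably pin the mixing work to the CPU inside dataloader workers; the patch itself does not state the rationale.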
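
The ns/sr path tiles a sampled noise clip to the length of the target before mixing. repeat_extend_audio is not part of this patch; the sketch below reconstructs it from its name and call site, assuming quantized audio shaped [timesteps, prom_levels], so treat it as an assumption rather than the repo's actual implementation:

import torch

def repeat_extend_audio(qnt: torch.Tensor, target_len: int) -> torch.Tensor:
    # tile the clip along the time axis until it covers target_len frames,
    # then trim the excess
    reps = -(-target_len // qnt.shape[0])  # ceiling division
    return qnt.repeat(reps, 1)[:target_len, :]

# e.g. a 100-frame noise clip stretched to a 350-frame target
noise = torch.randint(0, 1024, (100, 8), dtype=torch.int16)
assert repeat_extend_audio(noise, 350).shape == (350, 8)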