diff --git a/vall_e/config.py b/vall_e/config.py index ced46cb..657e64a 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -167,6 +167,7 @@ class Dataset: reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method noise_scale: float = 0.25 # scaling noise value + inject_noise_in_prom: bool = False _frames_per_second: int = 0 # allows setting your own hint @@ -358,6 +359,7 @@ class LoRA: rank: int = 128 # rank for the LoRA alpha: int = 128 # rank for the LoRA training: bool = True # + embeddings: bool = False # train the embedding too parametrize: bool = False # rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA @@ -832,8 +834,7 @@ class Config(BaseConfig): if self.hyperparameters.scheduler == "": self.hyperparameters.torch_scheduler = True - if self.dataset.prompt_duration != 0: - self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] + self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] if self.trainer.backend == "local" and self.distributed: self.trainer.ddp = True diff --git a/vall_e/data.py b/vall_e/data.py index 1b33994..5ec4c6a 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -11,7 +11,7 @@ import torch import itertools from .config import cfg -from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file +from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler from .utils.distributed import global_rank, local_rank, world_size @@ -717,6 +717,9 @@ class Dataset(_Dataset): def sample_prompts(self, spkr_name, ignore, should_trim=True): + if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: + return None + prom_list = [] choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} @@ -748,7 +751,7 @@ class Dataset(_Dataset): qnt = _load_quants(path, return_metadata=False) if 0 < trim_length and trim_length < qnt.shape[0]: - qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat ) + qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device ) prom_list.append(qnt) prom_length += qnt.shape[0] @@ -758,10 +761,10 @@ class Dataset(_Dataset): # might be better to decode => concat waveforms with silence in between => reencode # as you technically can't just append encodec sequences together like this without issues - prom = torch.cat(prom_list) + prom = concat_audio( *prom_list, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device ) if 0 < trim_length and trim_length < prom.shape[0]: - prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat ) + prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device ) return prom @@ -855,6 +858,15 @@ class Dataset(_Dataset): if task == "tts": proms = self.sample_prompts(spkr_name, ignore=path) + if cfg.dataset.inject_noise_in_prom: + # sample random noise + noise = self.sample_noise() + # extend the noise to fill the target audio + noise = repeat_extend_audio(noise, proms.shape[0]) + # create the input prompt by merging the target audio with the noise + proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device ) + + # VALL-E Continuous ( => ) # (this could just be sampled as