diff --git a/README.md b/README.md
index ec570e7..c4e2fc7 100755
--- a/README.md
+++ b/README.md
@@ -147,6 +147,14 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * train and release a ***good*** model.
 * clean up the README, and document, document, document onto the wiki.
 * extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+* improve throughput:
+  - properly utilize RetNet's recurrent forward / chunkwise forward passes
+  - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding the N+1, N+2, N+3 AR tokens
+    + this requires a properly trained AR, however.
+* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
+  - "sliding" AR input, such as keeping the context at a fixed length.
+    + may require additional training to be aware of this, might not.
+    + may require some phoneme/codec alignment, might not.
 
 ## Notices and Citations
 
diff --git a/vall_e/config.py b/vall_e/config.py
index 058cf24..60e05d4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -156,17 +156,18 @@ class Dataset:
 
 @dataclass()
 class Model:
-	name: str = ""
-	version: int = 1
-	size: str | float | dict = "full"
-	resp_levels: int = 1
-	prom_levels: int = 8
+	name: str = "" # vanity name for the model
+	version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
+	size: str | dict = "full" # preset string or explicitly defined dimensionality
+	resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
+	prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
 	tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	langs: int = 0 # defined languages
-	arch_type: str = "retnet"
-	training: bool = True
-	interleave: bool = False
-	frozen_params: list[str] = field(default_factory=lambda: [])
+	arch_type: str = "retnet" # or "transformer"
+	training: bool = True # unneeded now
+	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental; worse performance and results)
+	p_ar_level: float | str = "auto" # odds of selecting the AR (level 0) when training; "auto" for the default behavior
+	frozen_params: list[str] = field(default_factory=lambda: []) # parameters to freeze (exclude from updates) when training
 
 	@property
 	def full_name(self):
diff --git a/vall_e/data.py b/vall_e/data.py
index 9e36c51..ea9a3a1 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -291,8 +291,8 @@ class Dataset(_Dataset):
 
 		# shuffle it up a bit
 		prom_length = 0
-		#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
-		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+		trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+		#trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
 
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index f146934..53a3793 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -106,7 +106,11 @@ class AR_NAR(Base):
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			if cfg.models.ar_nar.p_ar_level == "auto":
+				quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			else:
+				quant_levels = torch.tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels - 1) for _ in range(batch_size) ]) # weighted coin flip between the AR (0) and a random NAR level (1+)
+
 			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
 			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
 			quant_levels.to(device=device)
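
A few notes on the changes above. The `p_ar_level` option added in `vall_e/config.py` feeds the level selection in the `ar_nar.py` hunk: with `"auto"`, the AR (level 0) is drawn uniformly alongside the NAR levels; with a float, it is drawn with that probability. A minimal standalone sketch of the same logic (the function name `pick_quant_levels` is made up for illustration):

```python
import random
import torch

def pick_quant_levels(batch_size: int, n_resp_levels: int, p_ar_level="auto") -> torch.Tensor:
    """Pick one target RVQ-bin level per batch item: 0 trains the AR, 1+ trains the NAR."""
    if p_ar_level == "auto":
        # uniform over all levels: the AR is selected 1/n_resp_levels of the time
        return torch.randint(0, n_resp_levels, (batch_size,))
    # weighted coin flip: the AR with probability p_ar_level, else a uniform NAR level
    return torch.tensor([
        0 if random.random() < p_ar_level else random.randint(1, n_resp_levels - 1)
        for _ in range(batch_size)
    ])

# with 8 levels, "auto" gives the AR a 1/8 chance; p_ar_level=0.25 raises that to 1/4
print(pick_quant_levels(4, 8, p_ar_level=0.25))
```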
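
On the `vall_e/data.py` hunk: the factor of 75 is EnCodec's frame rate (75 codebook frames per second at 24 kHz), so the re-enabled line draws prompt lengths uniformly from 3 to 9 seconds instead of centering them on `cfg.dataset.prompt_duration` ± 1 second. A sketch of how such a trim could apply to a `(frames, levels)` code tensor (the `random_trim` helper is hypothetical, not the repo's actual trimming code):

```python
import random
import torch

FRAMES_PER_SECOND = 75  # EnCodec at 24 kHz

def random_trim(codes: torch.Tensor, min_sec: int = 3, max_sec: int = 9) -> torch.Tensor:
    """Trim a (frames, levels) code tensor to a random 3-9 second window."""
    trim_length = random.randint(FRAMES_PER_SECOND * min_sec, FRAMES_PER_SECOND * max_sec)
    if codes.shape[0] <= trim_length:
        return codes
    start = random.randint(0, codes.shape[0] - trim_length)
    return codes[start:start + trim_length]
```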
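
For the Medusa-style throughput item in the README hunk: the general technique bolts K extra heads onto the AR so that one forward pass drafts the N+1, N+2, N+3 tokens, which a follow-up pass then verifies. This is a sketch of the idea only, not code from this repo or from Medusa itself; all names here are illustrative:

```python
import torch
import torch.nn as nn

class MedusaStyleHeads(nn.Module):
    """Illustrative: K extra heads, each guessing one additional future AR token."""
    def __init__(self, d_model: int, n_tokens: int, k: int = 3):
        super().__init__()
        # one linear head per lookahead position (N+1, N+2, N+3 for k=3)
        self.heads = nn.ModuleList(nn.Linear(d_model, n_tokens) for _ in range(k))

    def forward(self, last_hidden: torch.Tensor) -> torch.Tensor:
        # last_hidden: (batch, d_model), the state that normally feeds the base LM head
        # returns (batch, k, n_tokens): one set of logits per speculated position
        return torch.stack([head(last_hidden) for head in self.heads], dim=1)
```

Decoding keeps the longest drafted prefix the base model agrees with, so wrong guesses cost nothing beyond the verification pass; as the README notes, the drafts are only usually right once the AR itself is properly trained.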
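
And for the "sliding" AR input item: the idea is to cap the sequence fed back into the AR at a fixed length so inference never extends past the trained context. A minimal sketch, assuming a hypothetical `max_context` measured in tokens:

```python
import torch

def slide_ar_context(sequence: torch.Tensor, max_context: int) -> torch.Tensor:
    """Keep only the newest max_context tokens as the AR's input window."""
    if sequence.shape[-1] <= max_context:
        return sequence
    return sequence[..., -max_context:]
```

The README's caveats apply here: the model may need additional training to tolerate a window that starts mid-utterance, and naively dropping the oldest codes can orphan phonemes the model is still trying to speak.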