added an option to set the probability of selecting the AR during training under a monolithic AR+NAR; added some more to-dos while I have them in mind

mrq 2023-10-02 16:52:42 -05:00
parent e85b798fbf
commit d12877ee09
4 changed files with 25 additions and 12 deletions

View File

@@ -147,6 +147,14 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * train and release a ***good*** model.
 * clean up the README, and document, document, document onto the wiki.
 * extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+* improve throughput:
+  - properly utilize RetNet's recurrent forward / chunkwise forward passes
+  - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding the N+1, N+2, N+3 AR tokens
+    + this requires a properly trained AR, however.
+* work around issues with extending context past what was trained (despite RetNet's retention allegedly being able to defeat this):
+  - "sliding" AR input, such as keeping the context at a fixed length (a rough sketch follows this hunk).
+    + may require additional training to be aware of this, might not.
+    + may require some phoneme/codec alignment, might not.
 ## Notices and Citations
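
The "sliding" AR input item above is only a to-do; as a rough sketch of the idea (not code from this commit — `trim_context` and `max_resp_context` are hypothetical names), the decoded sequence fed back into the AR each step gets windowed to a fixed length:

```python
import torch

max_resp_context = 75 * 3  # hypothetical cap: ~3 seconds of tokens at EnCodec's ~75 frames/sec

def trim_context(resps: torch.Tensor, limit: int = max_resp_context) -> torch.Tensor:
    # feed the model only the most recent AR tokens; the full sequence
    # is still accumulated outside this function for the final output
    if limit <= 0 or resps.shape[0] <= limit:
        return resps
    return resps[-limit:]

# in a hypothetical sampling loop:
# logits = model(text_list, proms_list, [trim_context(r) for r in resps_list])
```

As the to-do itself hedges, the model may need extra training to handle inputs that start mid-utterance, and the phoneme/codec conditioning may need realigning to the window.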

View File

@@ -156,17 +156,18 @@ class Dataset:
 @dataclass()
 class Model:
-    name: str = ""
-    version: int = 1
-    size: str | float | dict = "full"
-    resp_levels: int = 1
-    prom_levels: int = 8
+    name: str = "" # vanity name for the model
+    version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
+    size: str | dict = "full" # preset string or explicitly defined dimensionality
+    resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
+    prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
     tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     langs: int = 0 # defined languages
-    arch_type: str = "retnet"
-    training: bool = True
-    interleave: bool = False
-    frozen_params: list[str] = field(default_factory=lambda: [])
+    arch_type: str = "retnet" # or "transformer"
+    training: bool = True # unneeded now
+    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
+    p_ar_level: float | str = "auto" # determines the odds of selecting the AR (level 0) when training; "auto" keeps the default behavior (see the example after this hunk)
+    frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     @property
     def full_name(self):
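
To illustrate the new `p_ar_level` field, a minimal example constructing the `Model` dataclass above directly (values are made up; the YAML config maps onto these same fields):

```python
# assumes the Model dataclass defined above is in scope
ar_nar = Model(
    name="ar+nar",
    resp_levels=8,
    prom_levels=8,
    p_ar_level=0.25,  # ~25% of training steps target the AR (level 0)
)

# the default keeps the old behavior: a level in 0..resp_levels-1, drawn uniformly
assert Model(name="ar+nar").p_ar_level == "auto"
```

Under "auto" with eight levels, the AR is only trained one step in eight on average; the new field lets that ratio be weighted up or down.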

View File

@@ -291,8 +291,8 @@ class Dataset(_Dataset):
         # shuffle it up a bit
         prom_length = 0
-        #trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
-        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+        trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+        #trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
 
         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
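
For reference, the 75 in these bounds is EnCodec's frame rate (~75 token frames per second at 24 kHz), so the restored line samples prompts between 3 and 9 seconds. A quick sanity check, with `prompt_duration` standing in for `cfg.dataset.prompt_duration`:

```python
import random

trim_length = random.randint(75 * 3, 75 * 9)  # 225..675 frames, i.e. 3..9 seconds
assert 225 <= trim_length <= 675

# the now-commented-out alternative: a configured duration, jittered by up to a second
prompt_duration = 3.0  # stand-in for cfg.dataset.prompt_duration
trim_length = int(prompt_duration * 75) + random.randint(-75, 75)
```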

View File

@@ -106,7 +106,11 @@ class AR_NAR(Base):
         # is training
         if n_levels == self.n_resp_levels:
-            quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            if cfg.models.ar_nar.p_ar_level == "auto":
+                quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            else:
+                quant_levels = torch.tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels - 1) for _ in range(batch_size) ]) # randint is inclusive, so the top NAR level is n_resp_levels - 1
             targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
             resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do r[..., :max(1, l)]
             quant_levels = quant_levels.to(device=device) # .to() is not in-place, so reassign
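
Pulled out of the model for clarity, the selection logic amounts to the following (a standalone sketch; `sample_quant_levels` is not a function in the repo, and `p_ar_level`/`n_resp_levels` stand in for the config and model attributes):

```python
import random
import torch

def sample_quant_levels(batch_size: int, n_resp_levels: int, p_ar_level: float | str = "auto") -> torch.Tensor:
    # pick a target RVQ-bin level per batch item: 0 trains the AR, 1+ train the NAR
    if p_ar_level == "auto":
        # uniform over all levels, so the AR is selected 1/n_resp_levels of the time
        return torch.randint(0, n_resp_levels, (batch_size,))
    # weighted: AR with probability p_ar_level, otherwise a uniformly drawn NAR level
    return torch.tensor([
        0 if random.random() < p_ar_level else random.randint(1, n_resp_levels - 1)
        for _ in range(batch_size)
    ])

levels = sample_quant_levels(batch_size=4, n_resp_levels=8, p_ar_level=0.5)
assert int(levels.min()) >= 0 and int(levels.max()) < 8
```

Note that for an 8-level model, `p_ar_level = 0.125` reproduces the "auto" odds exactly: the AR is hit 1/8 of the time, and each NAR level gets (7/8) × (1/7) = 1/8.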