added option to set probability of selecting the AR during training under a monolithic AR+NAR, added some more to-dos while I have them in mind
parent e85b798fbf
commit d12877ee09
@@ -147,6 +147,14 @@ And some experimental sampling flags you can use too (your mileage will ***definitely*** vary)
 * train and release a ***good*** model.
 * clean up the README, and document, document, document onto the wiki.
 * extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+* improve throughput:
+  - properly utilize RetNet's recurrent forward / chunkwise forward passes
+  - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
+    + this requires a properly trained AR, however.
+* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
+  - "sliding" AR input, such as keeping the context a fixed length.
+    + may require additional training to be aware of this, might not.
+    + may require some phoneme/codec alignment, might not.
 
 ## Notices and Citations
@@ -156,17 +156,18 @@ class Dataset:
 
 @dataclass()
 class Model:
-    name: str = ""
-    version: int = 1
-    size: str | float | dict = "full"
-    resp_levels: int = 1
-    prom_levels: int = 8
+    name: str = "" # vanity name for the model
+    version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
+    size: str | dict = "full" # preset string or explicitly defined dimensionality
+    resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
+    prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
     tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     langs: int = 0 # defined languages
-    arch_type: str = "retnet"
-    training: bool = True
-    interleave: bool = False
-    frozen_params: list[str] = field(default_factory=lambda: [])
+    arch_type: str = "retnet" # or "transformer"
+    training: bool = True # unneeded now
+    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
+    p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
+    frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 
     @property
     def full_name(self):
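As a quick reference for the new field, here's a minimal, self-contained sketch of a trimmed-down `Model` config carrying `p_ar_level`; the field names mirror the dataclass above, but the cut-down class itself and the `example` instance are purely illustrative, not the repo's actual config object:

```python
# Minimal sketch only: a cut-down stand-in for the Model dataclass above,
# limited to the fields relevant here.
from dataclasses import dataclass, field

@dataclass()
class Model:
    name: str = ""                    # vanity name for the model
    resp_levels: int = 1              # RVQ-bin levels this model targets for outputs
    prom_levels: int = 8              # RVQ-bin levels accepted as an input prompt
    interleave: bool = False          # interleaved AR instead of a split AR + NAR
    p_ar_level: float | str = "auto"  # odds of training the AR (level 0); "auto" = uniform over levels
    frozen_params: list[str] = field(default_factory=list)

# "auto" keeps the previous behavior (every RVQ level equally likely per sample);
# a float biases sampling toward the AR, e.g. ~25% of samples target level 0 here.
example = Model(name="ar+nar", resp_levels=8, p_ar_level=0.25)
print(example)
```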
@@ -291,8 +291,8 @@ class Dataset(_Dataset):
 
         # shuffle it up a bit
         prom_length = 0
-        #trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
-        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+        trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+        #trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
 
         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
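Roughly what the now-active `trim_length` line does to an input prompt, as a self-contained sketch; the 75 codes-per-second rate comes from the `# [3 seconds, 9 seconds]` comment (EnCodec at 24kHz emits 75 code frames per second), while `random_trim` and the `(frames, levels)` layout are assumptions for illustration:

```python
# Illustrative only: trims a (frames, rvq_levels) tensor of EnCodec codes to a
# random 3-9 second window, mirroring the active line above.
import random
import torch

FRAMES_PER_SECOND = 75  # EnCodec @ 24kHz -> 75 code frames per second

def random_trim(codes: torch.Tensor) -> torch.Tensor:
    trim_length = random.randint(FRAMES_PER_SECOND * 3, FRAMES_PER_SECOND * 9)
    if codes.shape[0] <= trim_length:
        return codes  # already shorter than the window, keep as-is
    start = random.randint(0, codes.shape[0] - trim_length)
    return codes[start:start + trim_length]

prompt = random_trim(torch.zeros(FRAMES_PER_SECOND * 20, 8, dtype=torch.int64))
print(prompt.shape)  # somewhere between (225, 8) and (675, 8)
```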
@@ -106,7 +106,11 @@ class AR_NAR(Base):
 
         # is training
         if n_levels == self.n_resp_levels:
+            if cfg.models.ar_nar.p_ar_level == "auto":
                 quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            else:
+                quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
 
             targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
             resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
             quant_levels.to(device=device)
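And a quick, standalone look at what `p_ar_level` does to the AR/NAR split during training; the helper below is illustrative (levels are simplified to `0..n_resp_levels-1` and the names are made up), but the thresholding follows the branch added above:

```python
# Sketch of the level-selection odds: "auto" picks any level uniformly, so the AR
# (level 0) sees ~1/n of the samples; a float p picks the AR with probability ~p.
import random
from collections import Counter

def pick_level(p_ar_level, n_resp_levels: int = 8) -> int:
    if p_ar_level == "auto":
        return random.randint(0, n_resp_levels - 1)   # uniform over all levels
    if random.random() < p_ar_level:
        return 0                                      # AR (level 0)
    return random.randint(1, n_resp_levels - 1)       # some NAR level

for p in ("auto", 0.25, 0.5):
    counts = Counter(pick_level(p) for _ in range(10_000))
    print(f"p_ar_level={p!r}: AR share ≈ {counts[0] / 10_000:.2f}")
```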