diff --git a/README.md b/README.md
index adf9d71..d423c09 100755
--- a/README.md
+++ b/README.md
@@ -147,16 +147,17 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 ## To-Do
 * train and release a ***good*** model.
+  - the current model seems to require a ***long*** training time at a very small learning rate to cover a wide variety of speakers with varying acoustics.
 * clean up the README, and document, document, document onto the wiki.
 * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+  - training additional tasks needs the SpeechX implementation to be reworked.
 * improve throughput (despite peaking at 120it/s):
-  - properly utilize RetNet's recurrent forward / chunkwise forward passes
+  - properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to work no matter how the model is trained).
   - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
+    + this requires a properly trained AR, however.
 * work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
   - "sliding" AR input, such as having the context a fixed length.
-    + may require additional training to be aware of this, might not.
-    + may require some phoneme/codec alignment, might not.
+    + the model may need to be trained for this with a fancy positional embedding injected. Naively sliding the context window while using the RetNet implementation's positional embedding doesn't seem fruitful.
 ## Notices and Citations
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index f2949db..222d0af 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -321,6 +321,9 @@ class Engines(dict[str, Engine]):
 			)
 		if cfg.trainer.restart_step_count:
 			engine.global_steps = 0
+			engine.micro_steps = 0
+			engine.global_samples = 0
+			engine.tokens_processed = 0
 		# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
 		if cfg.hyperparameters.scheduler_type == "":
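The `base.py` hunk above zeroes every training-progress counter (not just `global_steps`) when `cfg.trainer.restart_step_count` is set, so a resumed checkpoint keeps its weights but restarts its step accounting from scratch. A minimal standalone sketch of that behavior, with a stand-in `Engine` class whose attribute names mirror the diff (this is not the project's actual class):

```python
class Engine:
    """Stand-in for the training engine's progress counters."""

    def __init__(self) -> None:
        self.global_steps = 0      # optimizer steps taken overall
        self.micro_steps = 0       # forward/backward passes (incl. grad accumulation)
        self.global_samples = 0    # total samples seen
        self.tokens_processed = 0  # total tokens seen


def restart_step_count(engine: Engine) -> None:
    """Zero all progress counters, as the diff does on checkpoint load."""
    engine.global_steps = 0
    engine.micro_steps = 0
    engine.global_samples = 0
    engine.tokens_processed = 0
```

Resetting all four together keeps throughput and sample-count logging consistent with the restarted step count, instead of showing stale totals from the previous run.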
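The README's "sliding AR input" to-do item can be illustrated with a hedged sketch: keep only the most recent `window` tokens as context for each autoregressive step, so the sequence fed to the model never grows past a fixed length. The function name and signature here are illustrative, not part of the project's API, and (as the to-do notes) the real model would likely need retraining with a suitable positional embedding for this to work:

```python
def slide_context(tokens: list[int], window: int) -> list[int]:
    """Bound AR context length by keeping at most the last `window` tokens."""
    if window <= 0:
        raise ValueError("window must be positive")
    # Short sequences pass through unchanged; long ones are truncated
    # from the front so the newest tokens are always retained.
    return tokens if len(tokens) <= window else tokens[-window:]
```

This is only the bookkeeping half of the idea; whether the model's retention/positional scheme tolerates the truncated context is exactly the open question the to-do item raises.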