resetting step count resets the samples processed and other metrics

master
mrq 2023-10-29 12:11:19 +07:00
parent 0aa2a3cc07
commit 6c51a629cc
2 changed files with 7 additions and 3 deletions

@@ -147,16 +147,17 @@ And some experimental sampling flags you can use too (your mileage will ***defin
## To-Do
* train and release a ***good*** model.
- the current model seems to require ***long*** training at a very small learning rate to cover a wide variety of speakers with varying acoustics.
* clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks needs the SpeechX implementation to be reworked.
* improve throughput (despite peaking at 120 it/s):
- properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens (a rough sketch follows this list)
+ this requires a properly trained AR, however.
* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
- "sliding" AR input, such as keeping the context at a fixed length.
+ may require additional training to be aware of this, might not.
+ may require some phoneme/codec alignment, might not.
+ the model may need to be trained for this with a fancy positional embedding injected. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful.
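
For reference, a minimal sketch of the Medusa-style heads mentioned above, assuming a decoder that exposes its final hidden state; the class name, dimensions, and interface below are hypothetical illustrations, not this repo's code:

```python
import torch
import torch.nn as nn

class SpeculativeHeads(nn.Module):
    """Medusa-style extra heads (a sketch): each head predicts the token at a
    further offset (t+1, t+2, t+3) from the same final hidden state, so
    several AR tokens can be proposed from one forward pass."""

    def __init__(self, d_model: int, vocab_size: int, num_heads: int = 3):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size) for _ in range(num_heads)
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: [batch, d_model], the decoder state at the last position
        # returns: [batch, num_heads, vocab_size] logits for offsets t+1..t+K
        return torch.stack([head(hidden) for head in self.heads], dim=1)

# usage sketch: greedily propose K tokens; verification against the base AR
# (not shown) would accept only the longest prefix the AR agrees with
heads = SpeculativeHeads(d_model=1024, vocab_size=1025)
hidden = torch.randn(1, 1024)             # stand-in for a real decoder state
proposals = heads(hidden).argmax(dim=-1)  # [1, 3] candidate token ids
```

As the list notes, these proposals are only as good as the base AR that has to verify them.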
## Notices and Citations

@@ -321,6 +321,9 @@ class Engines(dict[str, Engine]):
)
if cfg.trainer.restart_step_count:
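	# the step count alone isn't enough: the sample/token counters feed the logged metrics, so zero them too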
	engine.global_steps = 0
	engine.micro_steps = 0
	engine.global_samples = 0
	engine.tokens_processed = 0
# update the LR, because for some god-awful reason it gets overwritten when loading from a checkpoint, but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "":