Commit Graph

651 Commits

Author SHA1 Message Date
mrq
41d7c30ea5 added much cleaner non-causal mask generation 2024-11-22 19:43:32 -06:00
mrq
c99a74e834 actually generate a causal mask because it seems sometimes it does not actually generate one because it makes assumptions 2024-11-22 18:30:24 -06:00
mrq
ccee5fc11c that was actually all pointless since sdpa always had an attention mask fed to it and does not need is_causal to implicitly generate one 2024-11-22 16:51:50 -06:00
mrq
4aa685e749 what has science done 2024-11-22 16:45:40 -06:00
mrq
147219a5e0 huge oversight in the attention masking......... (i realized I have not been providing a non-causal mask to non-causal tasks) 2024-11-22 13:44:43 -06:00
mrq
24d888c47c temporarily dropping support for xformers because it's breaking when using an attention mask (which i dont remember commenting it out when being passed), default to not use wandb because it's being a pain when doing tests and not actual sessionsS) 2024-11-22 11:29:12 -06:00
mrq
8aafae91fd dont use timeembedding 2024-11-21 23:14:52 -06:00
mrq
2cef97e43f cleanup 2024-11-21 23:08:43 -06:00
mrq
3fc0540f49 m 2024-11-21 15:07:46 -06:00
mrq
6845c447c9 added more harvard sentences to load from a text file 2024-11-21 13:18:11 -06:00
mrq
2a084544e8 moved duration padding for NAR-len to be a scalar instead (since it seems longer utterances need it much more so than shorter utterances) 2024-11-21 13:04:07 -06:00
mrq
6aee08f9c0 moved stuff in the web UI around (un-experimented the max NAR-len steps because its kind of important to adjust this value for better sounding audio / quicker generated audio) 2024-11-20 20:37:33 -06:00
mrq
dfdba3f190 oops 2024-11-20 19:21:03 -06:00
mrq
cd6e9ba2f2 oops 2024-11-20 16:27:51 -06:00
mrq
1a73ac6a20 I cannot believe it's not actually called Wand DB (added wandb logging support since I think it would have been a much better way to look at my metrics) 2024-11-20 16:10:47 -06:00
mrq
67f7bad168 added mixed modality AR+NAR-len to generate a short prefix through the AR, then inference with said prefix through the NAR-len (need to experiment with it more to ensure that the masked off tokens are the only tokens getting updated) 2024-11-20 14:22:12 -06:00
mrq
db64e6cb59 dependency updates (gradio 5.x now works on my machine) 2024-11-20 12:33:01 -06:00
mrq
b1369e7824 better modality selection (pick AR+NAR by default for the ar+nar model, pick NAR-len by default for the nar-len model), lowered default CFG because it makes the AR+NAR output sped up (but can't be too low since it's required for the NAR-len) 2024-11-19 18:51:17 -06:00
mrq
190a917b3e I did it. 2024-11-19 12:24:33 -06:00
mrq
0e621354e7 cleaned up classifier-free guidance logit processing (in order to try and cope with a bad nar-len model) 2024-11-19 10:30:05 -06:00
mrq
5ba80686e1 two weeks of agony concludes 2024-11-18 21:29:28 -06:00
mrq
2b29790173 oops 2024-11-18 14:12:26 -06:00
mrq
4a71981456 normalize sampler index by batch size (if not using batched sampler), add option to cap out utterances for a speaker, some other things 2024-11-18 12:46:50 -06:00
mrq
6cfdf94bf9 swap priority to use nar-len if available, added notes 2024-11-18 09:40:04 -06:00
mrq
069b27570f set option to set training masking ratio (I don't think for tts a fixed masking ratio is beneficial since the magic of the AR+NAR is being able to still reference the prior sequence of tokens for predicting things) 2024-11-17 17:04:07 -06:00
mrq
88d840218d default set cfg strength to 3.0 since the reference model is updated 2024-11-17 10:23:40 -06:00
mrq
a3e1fa3518 ugh 2024-11-17 09:28:33 -06:00
mrq
23fdba0c98 tweaks and changes 2024-11-16 15:49:06 -06:00
mrq
2fbeacfe92 ugh 2024-11-14 22:18:33 -06:00
mrq
39096f8ff3 redid loss calculation to be cleaner, and position ID generation, and other things (I might need to train the NAR-len from scratch and not resume from an existing checkpoint.........) 2024-11-14 22:17:47 -06:00
mrq
ef05c951ff adjust fp16 loss scaling since I fried a model overnight when it hit 8K scale 2024-11-14 09:23:52 -06:00
mrq
e412e98125 ugh 2024-11-14 07:34:22 -06:00
mrq
c00fc18b62 actually use the right embedding for nar-len 2024-11-13 18:04:04 -06:00
mrq
3ea8a610d6 fix STT 2024-11-13 14:27:15 -06:00
mrq
910033343c overhauled how the right resp level / classifier gets picked to avoid cringemath 2024-11-13 13:31:17 -06:00
mrq
269648605e move NAR-len rvq level 0 to separate embedding 2024-11-13 11:38:58 -06:00
mrq
29e45be0b4 tweaks to bucket sampling 2024-11-13 11:09:24 -06:00
mrq
b2eca271a8 ugh 2024-11-13 10:35:44 -06:00
mrq
be83ddabaa better causal-ness for split loss calc, and also do masking for NAR-len for it 2024-11-13 10:17:52 -06:00
mrq
6b76419123 ugh 2024-11-13 09:54:20 -06:00
mrq
ad7cfffc00 NAR-len RVQ-0 was being trained causally............. 2024-11-13 09:43:50 -06:00
mrq
976ee87f6f resume iteration step in tqdm trainer, warn to logger if the sampler state dict was invalidated 2024-11-13 09:09:28 -06:00
mrq
8286aa54c8 do not pass timestep token/embedding since it doesn't seem to matter at all after all, fixed training masking rate to 80% because a paper said so 2024-11-13 09:07:10 -06:00
mrq
caf721c67b set it to zero because it'll make the stop token hide more often than not 2024-11-12 22:30:50 -06:00
mrq
0f2584eba7 new meme sampler PogChamp new meme sampler PogChamp (it sort of helps?) 2024-11-12 22:30:09 -06:00
mrq
663f07038d haha... (do not create a token dropout/noise mask when not training (this sadly didnt fix NAR-len output)) 2024-11-12 16:41:58 -06:00
mrq
b09328069e actually do CFG sampling for base AR+NAR tasks 2024-11-12 13:42:39 -06:00
mrq
2495a7ef67 Fixed STT in the web UI 2024-11-12 12:49:53 -06:00
mrq
8927bad7bc actually fixed rep pen (for ar and nar, it seems to help with nar unmasking) 2024-11-11 21:40:19 -06:00
mrq
ec92613847 actually pass input prompt length size to inference 2024-11-11 20:39:48 -06:00
mrq
b1df6a7bed reverted rep pen sampler due to a regression 2024-11-11 20:35:08 -06:00
mrq
b1f4db39c8 threw in CFG sampling for normal model as well to experiment with 2024-11-11 20:27:38 -06:00
mrq
2f56696506 overhauled inference/sampler kwargs to stop being a bloated mess 2024-11-11 20:21:16 -06:00
mrq
354f8e059d store dataset hash alongside state dict so it can be ignored if mismatched 2024-11-11 18:16:56 -06:00
mrq
f7b8b1e825 dropped subtrain dataloader since its useless to duplicate 2024-11-11 17:00:49 -06:00
mrq
cf9df71f2c use homwbrewed caching system for dataloader paths / durations (I'm pretty sure I am now triggering OOM killers with my entire dataset used) 2024-11-11 16:32:08 -06:00
mrq
a748e223ce tweaks 2024-11-11 12:40:41 -06:00
mrq
48490757da fixes 2024-11-10 20:37:50 -06:00
mrq
9def34cd66 lol 2024-11-10 12:48:41 -06:00
mrq
9cb0b6901b unified nar.py into ar_nar.py 2024-11-10 12:19:48 -06:00
mrq
a9d2faf2d7 all I can do now until I wait for the model to (re)train for pure NAR 2024-11-09 22:57:34 -06:00
mrq
ad7e290a5e ugh (ROCm seems to silently clamp any token value >= logits.shape[-1] for loss calculation, while cuda will throw an assert, making it hard to find this dumb fuckup) 2024-11-09 19:40:02 -06:00
mrq
943fe70c10 I don't know why this fixes an assert thrown but it does 2024-11-09 19:04:13 -06:00
mrq
f50d92ba6c Almost made a mistake 2024-11-09 18:12:54 -06:00
mrq
c6a38693a2 This better work 2024-11-09 18:04:59 -06:00
mrq
8b3d1cf70a Something's Wrong 2024-11-09 15:07:43 -06:00
mrq
dcd5fecff3 some cleanup while I wait for the NAR-len to train to an acceptable state (currently it performs okay, but only on audo after 3 seconds or so) 2024-11-09 12:12:46 -06:00
mrq
69b0b3b854 set timestep tensor to whatever the time embedding's dtype is because it'll gripe under amp 2024-11-09 00:11:16 -06:00
mrq
5a09a5f6e9 I forgot about the time embedding... 2024-11-08 22:46:26 -06:00
mrq
811b15d280 I suppose I just have a shit training method since the sampler is as solid as I can get it............... 2024-11-08 22:05:41 -06:00
mrq
13b54953bd agony 2024-11-08 13:34:39 -06:00
mrq
c127c4e488 'borrowed' a sampling scheduler for NAR-len's RVQ level 0 (better than before, but still not good enough) 2024-11-07 21:19:14 -06:00
mrq
e108c54daf new NAR-len training paradigm...... 2024-11-07 11:32:11 -06:00
mrq
ed174c589e ugh 2024-11-07 09:19:21 -06:00
mrq
d13ab00ad8 one more note 2024-11-07 09:11:21 -06:00
mrq
5698188824 あたしって、ほんとバカ 2024-11-07 09:10:18 -06:00
mrq
77ff23e319 repeat extend the prom to fill the initial tokens for nar-len (it somewhat works, the model just needs to train more) 2024-11-06 23:29:53 -06:00
mrq
a3bc26f7ec ugh 2024-11-06 23:16:28 -06:00
mrq
d606a693ff eval fix for nar-len 2024-11-06 23:14:16 -06:00
mrq
105ed51159 I guess I'll fall for the NAR-len meme again (I don't know where my previous weights are, so I need to train it again to test something) 2024-11-06 19:17:12 -06:00
mrq
bcabde3454 more notes 2024-11-06 13:51:28 -06:00
mrq
bfc5e1d723 agony 2024-11-05 22:30:49 -06:00
mrq
aefe8fcdad UGH 2024-11-05 22:13:58 -06:00
mrq
556d9db0d5 web UI support for HF ZeroGPU 2024-11-05 21:38:02 -06:00
mrq
e58a9469a3 move layerskip to experimental settings....... 2024-11-05 20:37:06 -06:00
mrq
bbc2de3713 ugh 2024-11-05 11:50:05 -06:00
mrq
9e65e05e83 more windows specific fixes, limit gradio to <5.0.0 on linux (it works on windows, but not on my linux machine tm) 2024-11-04 18:00:33 -06:00
mrq
c83670c38c Windows specific fixes (to-do: find libespeak-ng.dll automatically because it cannot be trusted to do it by default) 2024-11-03 19:19:15 -06:00
mrq
d229725c76 more adjustments (adjustments of early-exit entropy/varentropy thresholds, default rep pen being 1.5, experimental refine-on-stop, etc.) 2024-11-03 18:31:28 -06:00
mrq
aee08b7307 changed layerskip float16 training warning (since it didnt seem to fry on my 4xV100 system) 2024-11-03 09:58:29 -06:00
mrq
3826f9bae4 saner mask creation? (it doesnt matter, kv cache wont work) 2024-11-02 21:00:21 -05:00
mrq
ded746e157 very, very naive layerskip speculative sampling (it just checks if the current layer's state is good enough) 2024-11-02 11:49:05 -05:00
mrq
62fe5b0943 ughh 2024-11-01 22:36:48 -05:00
mrq
ec79230965 shuffled web UI options hidden by cfg.experimental to its own tab, expose early exit selection to inferencing (it kinda works naively, still need to implement self-speculation) 2024-11-01 21:30:06 -05:00
mrq
ef1c17430f skip step on nan loss (ironically I have not had a nan loss after adding this), throw exception with invalid cfg.dataset.sample_type and sample_order combination (because I was tricked by this in my yaml and had inconsistent vram usage) 2024-11-01 20:54:53 -05:00
mrq
fb8faa295b actually float16(+AMP) and layerskip is bad and will kill the model...... 2024-11-01 18:36:44 -05:00
mrq
edf1e66bf9 layerskip_r=6 fries the model so hard the loss is sub-1... 2024-11-01 17:06:07 -05:00
mrq
9b6c57bc57 third time's the charm (for some reason it escaped me that I should treat early exit loss as an aux_loss to be used with the normal loss, as if I was training a MoE's router) 2024-11-01 12:50:37 -05:00
mrq
76ebef45dc off-by-one... 2024-10-31 13:24:48 -05:00
mrq
b63293cbbe ugh 2024-10-30 22:49:11 -05:00