two weeks of agony concludes

mrq 2024-11-18 21:29:28 -06:00
parent 2b29790173
commit 5ba80686e1
4 changed files with 30 additions and 6 deletions

View File

@@ -73,6 +73,14 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
* this is technically already offered with `cfg.model.experimental.token_dropout_rate`, which mirrors masking, but it has not been experimented with to any large degree (a rough sketch of the idea follows).
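A rough sketch of what this kind of token dropout on the prior levels could look like during training; the helper, the mask token, and the call site are hypothetical, not the repo's actual implementation:

```python
import torch

# illustrative only: randomly replace a fraction of the prior levels' codes with a mask/dropout
# token, mirroring the masking objective (names and signature are hypothetical)
def apply_token_dropout( codes: torch.Tensor, dropout_rate: float, mask_token: int ) -> torch.Tensor:
    drop = torch.rand( codes.shape, device=codes.device ) < dropout_rate
    return torch.where( drop, torch.full_like( codes, mask_token ), codes )
```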
Unfortunately, this model does not seem to remain stable for longer utterances, and the following band-aid copes do not solve it:
* training for longer durations
* iterative inferencing (infer the first n seconds, then the next n seconds, etc.; see the sketch below)
* more demasking steps during inferencing
* training explicitly for NAR level 0
Despite being trained "flawlessly" (without any implementation issues), the model still seems to exhibit the same issues as if it had been trained erroneously (predicting the next token rather than the token in place).
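The iterative (windowed) inferencing cope above can be sketched roughly as follows; it mirrors the disabled bucketed-inference block added elsewhere in this commit, but the `denoise` helper and the 75-codes-per-second rate are assumptions made for illustration:

```python
import torch

# illustrative sketch of windowed demasking: demask the first window from scratch, then extend the
# sequence window by window, only partially re-noising what has already been generated
def windowed_inference( denoise, len_list, stop_token, device, codes_per_second=75 ):
    buckets = max( 1, max( l // codes_per_second for l in len_list ) )
    resps_list = [ torch.ones((0,), dtype=torch.int16, device=device) for _ in len_list ]

    for bucket in range( 1, buckets + 1 ):
        # grow each sequence to its next window, padding the new region with the stop/mask token
        window_lens = [ int( l * bucket / buckets ) for l in len_list ]
        resps_list = [
            torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * stop_token ])
            for resps, seq_len in zip( resps_list, window_lens )
        ]
        # the first window is demasked from full noise; later windows resume from partial noise
        start_noise = 1.0 if bucket == 1 else 0.5
        resps_list = denoise( resps_list, window_lens, start_noise=start_noise )

    return resps_list
```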
## Embeddings (and Classifiers)
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.

View File

@@ -131,7 +131,7 @@ def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
parser.add_argument("--experts", type=int, default=8) # set target dtype to export to
@@ -146,7 +146,7 @@ def main():
cfg.trainer.load_module_only = True
-if args.hf and args.lora:
+if args.hf and args.export_lora:
raise Exception("Requesting more than one callback")
if args.dtype != "auto":
@@ -160,7 +160,7 @@ def main():
callback = None
if args.hf:
callback = convert_to_hf
-elif args.lora:
+elif args.export_lora:
callback = extract_lora
elif args.split_classifiers:
callback = split_classifier_heads
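# hedged usage note: assuming this script is run as a module (the exact module path is an
# assumption, not confirmed by this diff), the renamed flag would be invoked as e.g.:
#   python3 -m vall_e.export --hf           # convert the weights to an HF-style checkpoint
#   python3 -m vall_e.export --export-lora  # extract the trained LoRA instead
# only one conversion callback may be requested per invocation, hence the exception above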

View File

@@ -206,11 +206,11 @@ class TTS():
model_nar = None
for name, engine in self.engines.items():
if "ar" in engine.hyper_config.capabilities:
if model_ar is None and "ar" in engine.hyper_config.capabilities:
model_ar = engine.module
if "len" in engine.hyper_config.capabilities:
if model_len is None and "len" in engine.hyper_config.capabilities:
model_len = engine.module
if "nar" in engine.hyper_config.capabilities:
if model_nar is None and "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
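# (note on the change above: the added `is None` guards make the first engine that declares a
#  capability win, so a later engine with overlapping capabilities no longer overwrites an
#  already-selected module)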
seed = set_seed(seed)

View File

@@ -258,6 +258,7 @@ class AR_NAR(Base):
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
@@ -265,6 +266,21 @@ class AR_NAR(Base):
# to-do: only do the first N seconds' worth of tokens, then the next N seconds' worth, etc. until the last window
# because for longer utterances the output absolutely degrades
"""
buckets = max([ l // 75 for l in len_list ])
original_len_list = [ l for l in len_list ]
for bucket in range(1,buckets+1):
len_list = [ int(l * bucket / buckets) for l in original_len_list ]
if bucket == 1:
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
else:
start_noise = 0.5
resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
prev_list = resps_list
"""
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
# ramp down over time
annealing = 1.0 - timestep
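# (hedged note: in MaskGIT-style demasking, an annealing value like this typically scales the
#  sampling temperature and shrinks the fraction of low-confidence tokens that get re-masked,
#  so early iterations resample most of the sequence while later ones only revisit the worst positions)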