two weeks of agony concludes
@@ -73,6 +73,14 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remaining levels.
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
* this is technically already offered with `cfg.model.experimental.token_dropout_rate`, which mirrors masking, but it has not been experimented with to any large degree (a rough sketch of the equivalence is given below).
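As a rough illustration of the equivalence, here is a minimal sketch, assuming a stand-alone helper and a dedicated mask token (neither is the repo's actual implementation):

```python
import torch

def token_dropout(tokens: torch.Tensor, mask_token: int, p: float) -> torch.Tensor:
	# replace a fraction `p` of tokens with the mask token;
	# with a fixed small `p` this is classic token dropout, while sampling
	# `p` per batch from a schedule recovers masked (demasking-style) training
	drop = torch.rand(tokens.shape, device=tokens.device) < p
	return torch.where(drop, torch.full_like(tokens, mask_token), tokens)

tokens  = torch.randint(0, 1024, (75,))
dropped = token_dropout(tokens, mask_token=1024, p=0.1)                  # token dropout
masked  = token_dropout(tokens, mask_token=1024, p=torch.rand(1).item()) # masking
```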
Unfortunately, this model does not seem to prove stable for longer utterances, and the following band-aid copes do not solve it:
* training for longer durations
* iterative inferencing (infer the first n seconds, then the next n seconds, etc.; see the sketch after this list)
* more demasking steps during inference
* training explicitly for NAR level 0
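The iterative approach is what the (commented-out) bucketing code added further down in this commit implements; a condensed sketch of the idea, with `infer_fn`, the window size, and the mask-token convention assumed for illustration:

```python
import torch

def iterative_infer(infer_fn, total_len: int, window: int = 75, mask_token: int = 1024):
	# demask window-by-window: keep the already-inferred prefix fixed and
	# only demask the newly appended (fully masked) window each round
	resps = torch.full((0,), mask_token, dtype=torch.int16)
	for end in range(window, total_len + window, window):
		end = min(end, total_len)
		pad = torch.full((end - resps.shape[0],), mask_token, dtype=torch.int16)
		resps = torch.cat([resps, pad])
		resps = infer_fn(resps)  # one full demasking pass over `resps`
	return resps
```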
Despite being trained "flawlessly" (without any implementation issues), it still seems to exhibit the same issues as if it were trained erroneously (predicting the next token rather than the token in place).
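To make the distinction concrete, a purely illustrative sketch of the two target constructions (not the repo's training code):

```python
import torch

tokens = torch.tensor([5, 9, 2, 7])

# next-token (AR) objective: position t predicts token t+1
ar_inputs, ar_targets = tokens[:-1], tokens[1:]

# in-place (masked/NAR) objective: position t predicts token t,
# with some input positions replaced by a mask token
mask_token = 1024
masked = tokens.clone()
masked[torch.tensor([1, 3])] = mask_token
nar_inputs, nar_targets = masked, tokens
```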
## Embeddings (and Classifiers)
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
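A minimal sketch of the contrast, with illustrative shapes and names (not the repo's): a proper ensemble keeps one embedding per sequence type, while the HF-compatible path offsets each type into its own ID range of one flat vocabulary.

```python
import torch
import torch.nn as nn

d_model, n_text, n_audio = 1024, 256, 1025
text_ids  = torch.randint(0, n_text,  (32,))
audio_ids = torch.randint(0, n_audio, (75,))

# ensemble: a dedicated embedding per sequence type
text_emb  = nn.Embedding(n_text,  d_model)
audio_emb = nn.Embedding(n_audio, d_model)
x_ensemble = torch.cat([ text_emb(text_ids), audio_emb(audio_ids) ], dim=0)

# HF-compatible: one big table, each type shifted into its own range
flat_emb = nn.Embedding(n_text + n_audio, d_model)
x_flat = torch.cat([ flat_emb(text_ids), flat_emb(audio_ids + n_text) ], dim=0)
```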
@@ -131,7 +131,7 @@ def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
-	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+	parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
 	parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
 	parser.add_argument("--moe-ify", action='store_true', default=None) # converts the model into a mixture-of-experts
 	parser.add_argument("--experts", type=int, default=8) # number of experts to MoE-ify into
@@ -146,7 +146,7 @@ def main():
 	cfg.trainer.load_module_only = True
 
 
-	if args.hf and args.lora:
+	if args.hf and args.export_lora:
 		raise Exception("Requesting more than one callback")
 
 	if args.dtype != "auto":
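As an aside, argparse can enforce this exclusivity itself rather than via a manual check; a sketch using the flags from the hunk above (whether this fits the script's remaining options is untested):

```python
import argparse

parser = argparse.ArgumentParser("Save trained model to path.")
group = parser.add_mutually_exclusive_group()
group.add_argument("--hf", action='store_true', default=None)
group.add_argument("--export-lora", action='store_true', default=None)
group.add_argument("--split-classifiers", action='store_true', default=None)
args = parser.parse_args()  # argparse errors out if more than one is passed
```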
@@ -160,7 +160,7 @@ def main():
 	callback = None
 	if args.hf:
 		callback = convert_to_hf
-	elif args.lora:
+	elif args.export_lora:
 		callback = extract_lora
 	elif args.split_classifiers:
 		callback = split_classifier_heads
@@ -206,11 +206,11 @@ class TTS():
 		model_nar = None
 
 		for name, engine in self.engines.items():
-			if "ar" in engine.hyper_config.capabilities:
+			if model_ar is None and "ar" in engine.hyper_config.capabilities:
 				model_ar = engine.module
-			if "len" in engine.hyper_config.capabilities:
+			if model_len is None and "len" in engine.hyper_config.capabilities:
 				model_len = engine.module
-			if "nar" in engine.hyper_config.capabilities:
+			if model_nar is None and "nar" in engine.hyper_config.capabilities:
 				model_nar = engine.module
 
 		seed = set_seed(seed)
@@ -258,6 +258,7 @@ class AR_NAR(Base):
 		resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
 
+		scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
 
 		quant_levels = [ level for _ in range(batch_size) ]
 		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
 		null_prom = [ None for _ in range(batch_size) ]
@@ -265,6 +266,21 @@ class AR_NAR(Base):
 
+		# to-do: only infer the first N tokens, then the next N tokens, etc. until the last window,
+		# because for longer utterances it absolutely degrades
+		"""
+		buckets = max([ l // 75 for l in len_list ])
+		original_len_list = [ l for l in len_list ]
+		for bucket in range(1, buckets+1):
+			len_list = [ int(l * bucket / buckets) for l in original_len_list ]
+			if bucket == 1:
+				resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+			else:
+				start_noise = 0.5
+				resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
+
+			scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
+			prev_list = resps_list
+		"""
 
 		for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
 			# ramp down over time
 			annealing = 1.0 - timestep
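For context, one step of this loop conceptually performs MaskGIT-style demasking; a hedged sketch of a single step, with greedy sampling and all names assumed for illustration (this is not the repo's exact code):

```python
import torch

def demask_step(logits: torch.Tensor, mask_token: int, annealing: float) -> torch.Tensor:
	# pick a token for every position, then re-mask the least-confident
	# fraction of positions (a fraction that shrinks as `annealing` ramps down)
	probs = logits.softmax(dim=-1)   # logits: (seq_len, vocab)
	sampled = probs.argmax(dim=-1)   # greedy for brevity
	confidence = probs.max(dim=-1).values
	num_to_mask = int(annealing * logits.shape[0])
	if num_to_mask > 0:
		worst = confidence.topk(num_to_mask, largest=False).indices
		sampled[worst] = mask_token
	return sampled
```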