I did it.

This commit is contained in:
mrq 2024-11-19 12:24:33 -06:00
parent 0e621354e7
commit 190a917b3e
6 changed files with 19 additions and 50 deletions

View File

@@ -73,13 +73,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
* this is technically already offered with `cfg.model.experimental.token_dropout_rate`, which mirrors masking, but it has not been experimented with to any large degree.
Unfortunately, this model does not seem to remain stable for longer utterances, and the following band-aid copes do not solve it:
* training for longer durations
* iterative inferencing (inferring the first n seconds, then the next n seconds, and so on)
* more demasking steps during inference
* training explicitly for NAR level 0
Despite being trained "flawlessly" (without any implementation issues), the model still exhibits the same problems as if it had been trained erroneously (predicting the next token rather than the token in place).
It is ***crucial*** to:
* avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process; see the sketch after this list)
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem, but I imagine re-masking "interleaves" many different sequences between each step.
* use unfiltered/unprocessed logit scores:
* not as crucial, but it helps stability.
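
A minimal sketch of the scoring rule those last two bullets describe, assuming `logits` is a `[seq_len, vocab]` tensor from the model, `sampled_ids` are the tokens just sampled, and `finalized` marks positions accepted on earlier steps (all names here are illustrative, not the model's actual variables):

```python
import torch

def remask_scores( logits: torch.Tensor, sampled_ids: torch.Tensor, finalized: torch.Tensor ) -> torch.Tensor:
    # use raw (unfiltered) probabilities rather than post-filter ones, for stability
    probs = logits.softmax( dim=-1 )
    # confidence of the token actually placed at each position
    confidence = probs.gather( -1, sampled_ids.unsqueeze(-1) ).squeeze(-1)
    # "conjugate" the score so the worst tokens rank highest for re-masking
    scores = 1.0 - confidence
    # "ban" already-accepted tokens by forcing their score to the minimum,
    # so a topk over these scores can never re-mask them
    return torch.where( finalized, torch.zeros_like( scores ), scores )
```

A `topk` over these scores then only ever picks positions that are still in play, which is what keeps previously-good tokens from being churned between steps.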
## Embeddings (and Classifiers)

View File

@@ -790,6 +790,7 @@ class Config(BaseConfig):
sample_rate: int = 24_000 # sample rate the model expects
audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"
weights_name: str = "fp32"
weights_format: str = "sft" # "pth" | "sft"
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
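
For context, `"sft"` here presumably maps to safetensors (it is listed alongside `"safetensors"`), while `"pt"`/`"pth"` are PyTorch pickles; a rough sketch of how a saver/loader might branch on `weights_format` (an illustration, not the repo's actual serialization code):

```python
import torch
from safetensors.torch import save_file, load_file

def save_weights( state_dict: dict, path: str, fmt: str = "sft" ):
    if fmt in ("sft", "safetensors"):
        # safetensors stores tensors only, and they must be contiguous
        save_file( { k: v.contiguous() for k, v in state_dict.items() }, path )
    else:
        torch.save( state_dict, path )

def load_weights( path: str, fmt: str = "sft" ) -> dict:
    if fmt in ("sft", "safetensors"):
        return load_file( path )
    return torch.load( path, map_location="cpu" )
```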

View File

@@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
load_path = pick_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
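
`pick_path` is presumably a small helper that falls back through the supported weight formats until it finds a file that actually exists; a sketch of that assumed behavior (not the repo's implementation):

```python
from pathlib import Path

def pick_path( path: Path, *suffixes: str ) -> Path:
    # assumed behavior: prefer the requested path, otherwise try each
    # fallback suffix (".sft", ".safetensors", ".pt", ".pth") in order
    if path.exists():
        return path
    for suffix in suffixes:
        candidate = path.with_suffix( suffix )
        if candidate.exists():
            return candidate
    # nothing found; hand back the original so the caller can decide what to do
    return path
```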

View File

@@ -341,7 +341,7 @@ class Engines(dict[str, Engine]):
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
save_path = cfg.ckpt_dir / name / f"{cfg.weights_name}.{format}"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
# safety
@@ -350,7 +350,7 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
state_dict = {
'module': module,

View File

@@ -278,23 +278,6 @@ class AR_NAR(Base):
null_prom = [ None for _ in range(batch_size) ]
prev_list = resps_list
# to-do: only do the first N seconds of tokens, then the next N seconds, etc. until the last window
# because for longer utterances it absolutely degrades
"""
buckets = max([ l // 75 for l in len_list ])
original_len_list = [ l for l in len_list ]
for bucket in range(1,buckets+1):
len_list = [ int(l * bucket / buckets) for l in original_len_list ]
if bucket == 1:
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
else:
start_noise = 0.5
resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
prev_list = resps_list
"""
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
# ramp down over time
annealing = 1.0 - timestep
@@ -310,12 +293,9 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep
"""
sampling_temperature = temperature
sampling_cfg = cfg_strength
"""
# setup inputs
inputs = super().inputs(
@@ -364,7 +344,6 @@ class AR_NAR(Base):
)
# retrieves unfiltered logits
"""
unfiltered_sampled = super().sample(
logits=logits,
prev_list=prev_list,
@@ -373,29 +352,21 @@ class AR_NAR(Base):
temperature=0.0,
**sampling_kwargs,
)
"""
# update previous list of tokens
prev_list = resps_list
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores (conjugate to have worse scoring tokens picked for topk)
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
"""
# maskgct does some funny stuff but it doesn't amount to anything
if annealing < 1.0e-3:
sampled_ids = filtered_sampled.ids
else:
sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores (conjugated to put the worst scores at the top)
scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
"""
# get probability scores
scores = [
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
]
return resps_list
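
Taken together, the surviving code above is a MaskGIT-style demask loop: anneal temperature and CFG with the timestep, keep tokens accepted on earlier steps, and re-mask only the worst-scoring remaining positions. A condensed, single-utterance sketch of that loop, using greedy sampling, a linear schedule, and a stand-in `logits_fn` in place of the model's forward pass (all simplifications/assumptions, not the actual method):

```python
import torch

def demask_sketch( logits_fn, seq_len: int, mask_token: int, steps: int = 25 ) -> torch.Tensor:
    resps = torch.full( (seq_len,), mask_token, dtype=torch.long )   # start fully masked
    finalized = torch.zeros( (seq_len,), dtype=torch.bool )          # nothing accepted yet
    for step in range( steps ):
        timestep = step / steps   # 0 -> ~1; cfg strength would scale with this, temperature with (1 - timestep)
        logits = logits_fn( resps, timestep )                        # [seq_len, vocab]
        sampled = logits.argmax( dim=-1 )                            # greedy sampling is preferred
        # keep tokens that were already accepted on earlier steps
        resps = torch.where( finalized, resps, sampled )
        # score with raw probabilities; accepted tokens get 0 so they are never re-masked
        confidence = logits.softmax( dim=-1 ).gather( -1, resps.unsqueeze(-1) ).squeeze(-1)
        scores = torch.where( finalized, torch.zeros_like( confidence ), 1.0 - confidence )
        # linear schedule for how many positions stay masked going into the next step
        to_mask = int( seq_len * (1.0 - (step + 1) / steps) )
        if to_mask == 0:
            break   # final step: everything is kept as sampled
        # re-mask the worst-scoring positions, accept everything else
        remask = scores.topk( to_mask ).indices
        finalized = torch.ones_like( finalized )
        finalized[remask] = False
        resps[remask] = mask_token
    return resps
```

The details that matter line up with the bullets in the docs hunk: finalized positions are banned from re-masking via their zeroed scores, and the scores come from the raw softmax rather than any filtered distribution.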

View File

@@ -1588,7 +1588,6 @@ class Base(nn.Module):
mask = torch.cat([mask, padding], dim=1)
# needs to be done here as we still have our raw inputs
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )