I did it.
This commit is contained in:
parent
0e621354e7
commit
190a917b3e
|
@ -73,13 +73,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain
|
|||
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
|
||||
* this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree.
|
||||
|
||||
Unfortunately, this model does not seem to prove stable for longer utterances, and the following bandaid copes do not solve this:
|
||||
* training for longer durations
|
||||
* iterative inferencing (inference the first n seconds, then the next n seconds, etc.)
|
||||
* more demasking steps in inferencing
|
||||
* training explicitly for NAR level 0
|
||||
|
||||
Despite being trained "flawlessly" (without any implementation issues), it seems to still exhibit the same issues as if it were trained erroneously (to predict the next token rather than the token in place).
|
||||
It is ***crucial*** to:
|
||||
* avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process)
|
||||
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
|
||||
* use unfiltered/unprocessed logit scores:
|
||||
* not that crucial, but helps stability.
|
||||
|
||||
## Embeddings (and Classifiers)
|
||||
|
||||
|
|
|
@ -790,6 +790,7 @@ class Config(BaseConfig):
|
|||
sample_rate: int = 24_000 # sample rate the model expects
|
||||
audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
|
||||
|
||||
weights_name: str = "fp32"
|
||||
weights_format: str = "sft" # "pth" | "sft"
|
||||
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
|
||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
load_path = pick_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
# actually use the lora-specific checkpoint if available
|
||||
if cfg.lora is not None:
|
||||
|
|
|
@ -341,7 +341,7 @@ class Engines(dict[str, Engine]):
|
|||
for name, engine in self.items():
|
||||
module = engine.module.state_dict()
|
||||
lora = None
|
||||
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
|
||||
save_path = cfg.ckpt_dir / name / f"{cfg.weights_name}.{format}"
|
||||
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
|
||||
|
||||
# safety
|
||||
|
@ -350,7 +350,7 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
if cfg.lora is not None:
|
||||
lora, module = lora_get_state_dict( module, split = True )
|
||||
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
|
||||
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
|
||||
|
||||
state_dict = {
|
||||
'module': module,
|
||||
|
|
|
@ -278,23 +278,6 @@ class AR_NAR(Base):
|
|||
null_prom = [ None for _ in range(batch_size) ]
|
||||
prev_list = resps_list
|
||||
|
||||
# to-do: only do the Nth first tokens, then the Nth seconds tokens, etc. until the last window
|
||||
# because for longer utterances it absolutely degrades
|
||||
"""
|
||||
buckets = max([ l // 75 for l in len_list ])
|
||||
original_len_list = [ l for l in len_list ]
|
||||
for bucket in range(1,buckets+1):
|
||||
len_list = [ int(l * bucket / buckets) for l in original_len_list ]
|
||||
if bucket == 1:
|
||||
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
|
||||
else:
|
||||
start_noise = 0.5
|
||||
resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
|
||||
|
||||
scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
|
||||
prev_list = resps_list
|
||||
"""
|
||||
|
||||
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
|
||||
# ramp down over time
|
||||
annealing = 1.0 - timestep
|
||||
|
@ -310,12 +293,9 @@ class AR_NAR(Base):
|
|||
# timestep inputs
|
||||
time_list = [ timestep for _ in range(batch_size) ]
|
||||
|
||||
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
|
||||
sampling_temperature = temperature * annealing
|
||||
sampling_cfg = cfg_strength * timestep
|
||||
"""
|
||||
sampling_temperature = temperature
|
||||
sampling_cfg = cfg_strength
|
||||
"""
|
||||
|
||||
# setup inputs
|
||||
inputs = super().inputs(
|
||||
|
@ -364,7 +344,6 @@ class AR_NAR(Base):
|
|||
)
|
||||
|
||||
# retrieves unfiltered logits
|
||||
"""
|
||||
unfiltered_sampled = super().sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
|
@ -373,29 +352,21 @@ class AR_NAR(Base):
|
|||
temperature=0.0,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
"""
|
||||
# update previous list of tokens
|
||||
prev_list = resps_list
|
||||
# get sampled tokens
|
||||
sampled_ids = filtered_sampled.ids
|
||||
# keep unmasked tokens
|
||||
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||
# get probability scores (conjugate to have worse scoring tokens picked for topk)
|
||||
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
|
||||
|
||||
"""
|
||||
# maskgct does some funny stuff but it doesn't amount to anything
|
||||
if annealing < 1.0e-3:
|
||||
sampled_ids = filtered_sampled.ids
|
||||
else:
|
||||
sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]
|
||||
|
||||
# keep unmasked tokens
|
||||
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||
# update scores (conjugated to put the worst scores at the top)
|
||||
scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
|
||||
scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
|
||||
"""
|
||||
# get probability scores
|
||||
scores = [
|
||||
# conjugate to have worse scoring tokens picked for topk
|
||||
1.0 -
|
||||
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
|
||||
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
|
||||
# use unmodified logit scores for this, as it offers better stability
|
||||
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
|
||||
]
|
||||
|
||||
return resps_list
|
||||
|
||||
|
|
|
@ -1588,7 +1588,6 @@ class Base(nn.Module):
|
|||
mask = torch.cat([mask, padding], dim=1)
|
||||
|
||||
# needs to be done here as we still have our raw inputs
|
||||
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
|
||||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user