I did it.

This commit is contained in:
mrq 2024-11-19 12:24:33 -06:00
parent 0e621354e7
commit 190a917b3e
6 changed files with 19 additions and 50 deletions

View File

@ -73,13 +73,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level. * this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
* this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree. * this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree.
Unfortunately, this model does not seem to prove stable for longer utterances, and the following band-aid workarounds do not solve this: It is ***crucial*** to:
* training for longer durations * avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process)
* iterative inferencing (infer the first n seconds, then the next n seconds, etc.) * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem, but I imagine this "interleaves" many different sequences between each step.
* more demasking steps in inferencing * use unfiltered/unprocessed logit scores:
* training explicitly for NAR level 0 * not that crucial, but helps stability.
Despite being trained "flawlessly" (without any implementation issues), it seems to still exhibit the same issues as if it were trained erroneously (to predict the next token rather than the token in place).
## Embeddings (and Classifiers) ## Embeddings (and Classifiers)

View File

@ -790,6 +790,7 @@ class Config(BaseConfig):
sample_rate: int = 24_000 # sample rate the model expects sample_rate: int = 24_000 # sample rate the model expects
audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"" audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
weights_name: str = "fp32"
weights_format: str = "sft" # "pth" | "sft" weights_format: str = "sft" # "pth" | "sft"
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"]) supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])

View File

@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
checkpoint_path = cfg.ckpt_dir / name / "latest" checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) load_path = pick_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
# actually use the lora-specific checkpoint if available # actually use the lora-specific checkpoint if available
if cfg.lora is not None: if cfg.lora is not None:

View File

@ -341,7 +341,7 @@ class Engines(dict[str, Engine]):
for name, engine in self.items(): for name, engine in self.items():
module = engine.module.state_dict() module = engine.module.state_dict()
lora = None lora = None
save_path = cfg.ckpt_dir / name / f"fp32.{format}" save_path = cfg.ckpt_dir / name / f"{cfg.weights_name}.{format}"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
# safety # safety
@ -350,7 +350,7 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None: if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True ) lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}" save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
state_dict = { state_dict = {
'module': module, 'module': module,

View File

@ -278,23 +278,6 @@ class AR_NAR(Base):
null_prom = [ None for _ in range(batch_size) ] null_prom = [ None for _ in range(batch_size) ]
prev_list = resps_list prev_list = resps_list
# to-do: only do the Nth first tokens, then the Nth seconds tokens, etc. until the last window
# because for longer utterances it absolutely degrades
"""
buckets = max([ l // 75 for l in len_list ])
original_len_list = [ l for l in len_list ]
for bucket in range(1,buckets+1):
len_list = [ int(l * bucket / buckets) for l in original_len_list ]
if bucket == 1:
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
else:
start_noise = 0.5
resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
prev_list = resps_list
"""
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm): for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
# ramp down over time # ramp down over time
annealing = 1.0 - timestep annealing = 1.0 - timestep
@ -310,12 +293,9 @@ class AR_NAR(Base):
# timestep inputs # timestep inputs
time_list = [ timestep for _ in range(batch_size) ] time_list = [ timestep for _ in range(batch_size) ]
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
sampling_temperature = temperature * annealing sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep sampling_cfg = cfg_strength * timestep
"""
sampling_temperature = temperature
sampling_cfg = cfg_strength
"""
# setup inputs # setup inputs
inputs = super().inputs( inputs = super().inputs(
@ -364,7 +344,6 @@ class AR_NAR(Base):
) )
# retrieves unfiltered logits # retrieves unfiltered logits
"""
unfiltered_sampled = super().sample( unfiltered_sampled = super().sample(
logits=logits, logits=logits,
prev_list=prev_list, prev_list=prev_list,
@ -373,29 +352,21 @@ class AR_NAR(Base):
temperature=0.0, temperature=0.0,
**sampling_kwargs, **sampling_kwargs,
) )
"""
# update previous list of tokens # update previous list of tokens
prev_list = resps_list prev_list = resps_list
# get sampled tokens # get sampled tokens
sampled_ids = filtered_sampled.ids sampled_ids = filtered_sampled.ids
# keep unmasked tokens # keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores (conjugate to have worse scoring tokens picked for topk) # get probability scores
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ] scores = [
# conjugate to have worse scoring tokens picked for topk
""" 1.0 -
# maskgct does some funny stuff but it doesn't amount to anything # only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
if annealing < 1.0e-3: torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
sampled_ids = filtered_sampled.ids # use unmodified logit scores for this, as it offers better stability
else: for scores, masked in zip( unfiltered_sampled.scores, is_masked )
sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ] ]
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores (conjugated to put the worst scores at the top)
scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
"""
return resps_list return resps_list

View File

@ -1588,7 +1588,6 @@ class Base(nn.Module):
mask = torch.cat([mask, padding], dim=1) mask = torch.cat([mask, padding], dim=1)
# needs to be done here as we still have our raw inputs # needs to be done here as we still have our raw inputs
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" ) classifier_levels = self.get_input( inputs, name="classifier_level" )