I did it.
parent 0e621354e7
commit 190a917b3e
@@ -73,13 +73,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remaining levels
   * this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
   * this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree.
 
-Unfortunately, this model does not seem to prove stable for longer utterances, and the following bandaid copes do not solve this:
-* training for longer durations
-* iterative inferencing (inference the first n seconds, then the next n seconds, etc.)
-* more demasking steps in inferencing
-* training explicitly for NAR level 0
-
-Despite being trained "flawlessly" (without any implementation issues), it seems to still exhibit the same issues as if it were trained erroneously (to predict the next token rather than the token in place).
+It is ***crucial*** to:
+* avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process)
+  * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem, but I imagine it "interleaves" many different sequences between each step.
+* use unfiltered/unprocessed logit scores:
+  * not that crucial, but helps stability.
 
 ## Embeddings (and Classifiers)
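
The two "crucial" points above are easiest to see in code. Below is a minimal sketch of one demasking iteration with both fixes applied, assuming greedy picks and per-token probabilities come from the model's unfiltered distribution; the helper name `demask_step` and the re-mask schedule are illustrative assumptions, not this repo's actual API.

```python
import torch

def demask_step( resps, is_masked, probs, timestep ):
    # greedy picks and per-token confidences from the *unfiltered* distribution
    sampled = probs.argmax( dim=-1 )                                  # (T,)
    confidence = probs.gather( -1, sampled[:, None] ).squeeze( -1 )   # (T,)

    # only overwrite positions that were masked this step
    resps = torch.where( is_masked, sampled, resps )

    # "ban" already-finalized tokens by pinning them to maximal confidence,
    # so the worst-k selection below can never re-mask them
    scores = torch.where( is_masked, confidence, torch.ones_like( confidence ) )

    # anneal how many of the least-confident tokens get re-masked next step
    n_remask = int( (1.0 - timestep) * resps.shape[0] * 0.5 )
    new_mask = torch.zeros_like( is_masked )
    if n_remask > 0:
        new_mask[ scores.topk( n_remask, largest=False ).indices ] = True
    return resps, new_mask
```

Without the `torch.ones_like` ban, a token finalized at one step can be re-masked and resampled at the next, which is exactly the "interleaving" of different sequences blamed above for stuttering.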
@@ -790,6 +790,7 @@ class Config(BaseConfig):
     sample_rate: int = 24_000 # sample rate the model expects
     audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"
 
+    weights_name: str = "fp32"
     weights_format: str = "sft" # "pth" | "sft"
     supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
@@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
 
         checkpoint_path = cfg.ckpt_dir / name / "latest"
         # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-        load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+        load_path = pick_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 
         # actually use the lora-specific checkpoint if available
         if cfg.lora is not None:
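
For context, `pick_path` here resolves the first weights file that actually exists, trying the configured name/format first and then each supported fallback extension. A rough sketch of that behavior under those assumptions (not the repo's actual implementation):

```python
from pathlib import Path

def pick_path( path: Path, *suffixes: str ) -> Path:
    # try the requested file first, then the same stem with each fallback suffix
    if path.exists():
        return path
    for suffix in suffixes:
        candidate = path.with_suffix( suffix )
        if candidate.exists():
            return candidate
    # nothing found: return the original so the caller can fail with a clear path
    return path

# e.g. with weights_name="fp32" and weights_format="sft":
#   pick_path( ckpt_dir / "model" / "fp32.sft", ".safetensors", ".pt", ".pth" )
```

The commit's actual change is just the f-string: the hard-coded `fp32` stem becomes the configurable `cfg.weights_name`, so differently-named exports can be loaded without renaming files.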
@@ -341,7 +341,7 @@ class Engines(dict[str, Engine]):
         for name, engine in self.items():
             module = engine.module.state_dict()
             lora = None
-            save_path = cfg.ckpt_dir / name / f"fp32.{format}"
+            save_path = cfg.ckpt_dir / name / f"{cfg.weights_name}.{format}"
             config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
 
             # safety
@@ -350,7 +350,7 @@ class Engines(dict[str, Engine]):
 
             if cfg.lora is not None:
                 lora, module = lora_get_state_dict( module, split = True )
-                save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
+                save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
 
             state_dict = {
                 'module': module,
@@ -278,23 +278,6 @@ class AR_NAR(Base):
         null_prom = [ None for _ in range(batch_size) ]
         prev_list = resps_list
 
-        # to-do: only do the Nth first tokens, then the Nth seconds tokens, etc. until the last window
-        # because for longer utterances it absolutely degrades
-        """
-        buckets = max([ l // 75 for l in len_list ])
-        original_len_list = [ l for l in len_list ]
-        for bucket in range(1,buckets+1):
-            len_list = [ int(l * bucket / buckets) for l in original_len_list ]
-            if bucket == 1:
-                resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
-            else:
-                start_noise = 0.5
-                resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ]
-
-            scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
-            prev_list = resps_list
-        """
-
         for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
             # ramp down over time
             annealing = 1.0 - timestep
@@ -310,12 +293,9 @@ class AR_NAR(Base):
             # timestep inputs
             time_list = [ timestep for _ in range(batch_size) ]
 
+            # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
             sampling_temperature = temperature * annealing
             sampling_cfg = cfg_strength * timestep
-            """
-            sampling_temperature = temperature
-            sampling_cfg = cfg_strength
-            """
 
             # setup inputs
             inputs = super().inputs(
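
A note on the two schedules kept above: temperature is annealed toward zero as the timestep advances (sampling becomes greedier as more tokens are finalized), while classifier-free guidance strength ramps up from zero. The deleted docstring block was the constant (un-annealed) variant of the same two lines. A standalone sketch, with illustrative values:

```python
import torch

temperature, cfg_strength, max_steps = 1.0, 2.5, 25   # illustrative values
start_noise, end_noise = 0.0, 1.0

for timestep in torch.linspace( start_noise, end_noise, max_steps ):
    annealing = 1.0 - timestep
    sampling_temperature = temperature * annealing   # 1.0 -> 0.0: greedier over time
    sampling_cfg = cfg_strength * timestep           # 0.0 -> 2.5: guidance ramps up
```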
@@ -364,7 +344,6 @@ class AR_NAR(Base):
             )
 
             # retrieves unfiltered logits
-            """
             unfiltered_sampled = super().sample(
                 logits=logits,
                 prev_list=prev_list,
@@ -373,29 +352,21 @@ class AR_NAR(Base):
                 temperature=0.0,
                 **sampling_kwargs,
             )
-            """
 
             # update previous list of tokens
             prev_list = resps_list
             # get sampled tokens
             sampled_ids = filtered_sampled.ids
             # keep unmasked tokens
             resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
-            # get probability scores (conjugate to have worse scoring tokens picked for topk)
-            scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
-
-            """
-            # maskgct does some funny stuff but it doesn't amount to anything
-            if annealing < 1.0e-3:
-                sampled_ids = filtered_sampled.ids
-            else:
-                sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]
-
-            # keep unmasked tokens
-            resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
-            # update scores (conjugated to put the worst scores at the top)
-            scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
-            scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
-            """
+            # get probability scores
+            scores = [
+                # conjugate to have worse scoring tokens picked for topk
+                1.0 -
+                # only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
+                torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
+                # use unmodified logit scores for this, as it offers better stability
+                for scores, masked in zip( unfiltered_sampled.scores, is_masked )
+            ]
 
         return resps_list
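
Unrolled, the new scoring expression does three things per sequence: read per-token confidences from the unfiltered (greedy) sample, pin already-finalized positions to a perfect score so they are never re-masked, and conjugate so that a top-k over the result selects the *worst* tokens. A decomposed sketch of one sequence's worth of that comprehension (names are illustrative):

```python
import torch

def conjugated_scores( scores, masked, device="cpu" ):
    confidences = torch.tensor( [ score for score in scores ], device=device )
    # finalized positions get a perfect confidence of 1.0; after conjugation their
    # score is 0.0, so a topk over "worst" scores can never pick them for re-masking
    kept = torch.where( masked, confidences, torch.ones( masked.shape, device=device ) )
    # conjugate: higher now means "worse", ready for topk-based re-masking
    return 1.0 - kept
```

Mapping this over `zip( unfiltered_sampled.scores, is_masked )` reproduces the list comprehension in the diff.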
@@ -1588,7 +1588,6 @@ class Base(nn.Module):
             mask = torch.cat([mask, padding], dim=1)
 
             # needs to be done here as we still have our raw inputs
-            #position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
             position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
             classifier_levels = self.get_input( inputs, name="classifier_level" )
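
The dropped comment was a stale variant that derived the mask from `m` rather than the already-padded `mask`. For intuition, a common way to turn a padding mask into position ids (a generic sketch, not necessarily what `inputs_to_position_ids` does, since that helper also respects per-segment boundaries):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]])
# count valid tokens left-to-right, zero-based; padded slots clamp back to 0
position_ids = ( mask.cumsum( dim=-1 ) - 1 ).clamp( min=0 ) * mask
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 0]])
```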