From 190a917b3ed440b27853ddffb6cb95f854b36e09 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 19 Nov 2024 12:24:33 -0600 Subject: [PATCH] I did it. --- docs/models.md | 12 ++++------ vall_e/config.py | 1 + vall_e/engines/__init__.py | 2 +- vall_e/engines/base.py | 4 ++-- vall_e/models/ar_nar.py | 49 ++++++++------------------------------ vall_e/models/base.py | 1 - 6 files changed, 19 insertions(+), 50 deletions(-) diff --git a/docs/models.md b/docs/models.md index c203155..1a22ee7 100644 --- a/docs/models.md +++ b/docs/models.md @@ -73,13 +73,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain * this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level. * this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree. -Unfortunately, this model does not seem to prove stable for longer utterances, and the following bandaid copes do not solve this: -* training for longer durations -* iterative inferencing (inference the first n seconds, then the next n seconds, etc.) -* more demasking steps in inferencing -* training explicitly for NAR level 0 - -Despite being trained "flawlessly" (without any implementation issues), it seems to still exhibit the same issues as if it were trained erroneously (to predict the next token rather than the token in place). +It is ***crucial*** to: +* avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process) + * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step. +* use unfiltered/unprocessed logit scores: + * not that crucial, but helps stability. ## Embeddings (and Classifiers) diff --git a/vall_e/config.py b/vall_e/config.py index 85eb89d..dff2690 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -790,6 +790,7 @@ class Config(BaseConfig): sample_rate: int = 24_000 # sample rate the model expects audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"" + weights_name: str = "fp32" weights_format: str = "sft" # "pth" | "sft" supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"]) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 86fb68e..a3d6e29 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs): checkpoint_path = cfg.ckpt_dir / name / "latest" # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present - load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + load_path = pick_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) # actually use the lora-specific checkpoint if available if cfg.lora is not None: diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index ce065c2..0302346 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -341,7 +341,7 @@ class Engines(dict[str, Engine]): for name, engine in self.items(): module = engine.module.state_dict() lora = None - save_path = cfg.ckpt_dir / name / f"fp32.{format}" + save_path = cfg.ckpt_dir / name / f"{cfg.weights_name}.{format}" config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config # safety @@ -350,7 +350,7 @@ class Engines(dict[str, Engine]): if cfg.lora is not None: lora, module = lora_get_state_dict( module, split = True ) - save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}" + save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}" state_dict = { 'module': module, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 55a1cc4..3945335 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -278,23 +278,6 @@ class AR_NAR(Base): null_prom = [ None for _ in range(batch_size) ] prev_list = resps_list - # to-do: only do the Nth first tokens, then the Nth seconds tokens, etc. until the last window - # because for longer utterances it absolutely degrades - """ - buckets = max([ l // 75 for l in len_list ]) - original_len_list = [ l for l in len_list ] - for bucket in range(1,buckets+1): - len_list = [ int(l * bucket / buckets) for l in original_len_list ] - if bucket == 1: - resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ] - else: - start_noise = 0.5 - resps_list = [ torch.concat([ resps, torch.ones((seq_len - resps.shape[0],), dtype=torch.int16, device=device) * self.stop_token ]) for resps, seq_len in zip( resps_list, len_list ) ] - - scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ] - prev_list = resps_list - """ - for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm): # ramp down over time annealing = 1.0 - timestep @@ -310,12 +293,9 @@ class AR_NAR(Base): # timestep inputs time_list = [ timestep for _ in range(batch_size) ] + # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough sampling_temperature = temperature * annealing sampling_cfg = cfg_strength * timestep - """ - sampling_temperature = temperature - sampling_cfg = cfg_strength - """ # setup inputs inputs = super().inputs( @@ -364,7 +344,6 @@ class AR_NAR(Base): ) # retrieves unfiltered logits - """ unfiltered_sampled = super().sample( logits=logits, prev_list=prev_list, @@ -373,29 +352,21 @@ class AR_NAR(Base): temperature=0.0, **sampling_kwargs, ) - """ # update previous list of tokens prev_list = resps_list # get sampled tokens sampled_ids = filtered_sampled.ids # keep unmasked tokens resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] - # get probability scores (conjugate to have worse scoring tokens picked for topk) - scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ] - - """ - # maskgct does some funny stuff but it doesn't amount to anything - if annealing < 1.0e-3: - sampled_ids = filtered_sampled.ids - else: - sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ] - - # keep unmasked tokens - resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] - # update scores (conjugated to put the worst scores at the top) - scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ] - scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ] - """ + # get probability scores + scores = [ + # conjugate to have worse scoring tokens picked for topk + 1.0 - + # only keep scores of tokens we are predicting (and ignore the tokens previously finalized) + torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) ) + # use unmodified logit scores for this, as it offers better stability + for scores, masked in zip( unfiltered_sampled.scores, is_masked ) + ] return resps_list diff --git a/vall_e/models/base.py b/vall_e/models/base.py index cca36d7..30e475d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1588,7 +1588,6 @@ class Base(nn.Module): mask = torch.cat([mask, padding], dim=1) # needs to be done here as we still have our raw inputs - #position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None classifier_levels = self.get_input( inputs, name="classifier_level" )