diff --git a/vall_e/config.py b/vall_e/config.py
index 513af97..ec0a7a3 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -948,6 +948,7 @@ class Config(BaseConfig):
 			if not isinstance( model, dict ):
 				continue
 
+			# to-do: prune unused keys in here too automatically
 			if "experimental" not in model or not model["experimental"]:
 				model["experimental"] = {}
 
@@ -962,6 +963,9 @@ class Config(BaseConfig):
 			if "p_rvq_levels" in model["experimental"]:
 				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
 				del model["experimental"]["p_rvq_levels"]
+
+			if "p_len_train" in model["experimental"]:
+				del model["experimental"]["p_len_train"]
 
 		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
 		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index a9dd77f..e898b61 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -275,9 +275,9 @@ class AR_NAR(Base):
 		_super = super()
 		# to-do: allow for batch processing (it should probably work batched anyways)
 		def demask_sampling( batch_index, seq_len ):
-			# overrides
+			# overrides, to be user-controllable soonTM
 			max_steps = 10
-			temperature = 0.3
+			temperature = 1.0
 			cfg_strength = 1.0
 			sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
 			sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9