diff --git a/vall_e/config.py b/vall_e/config.py
index db10966..8dbc698 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -948,15 +948,6 @@ class Config(BaseConfig):
 				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
 				del model["interleave"]
 
-			if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
-				_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
-				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
-				del model["experimental"]["p_rvq_levels"]
-
-			if "audio_embedding_sums" in model:
-				_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
-				model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
-
 		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
 		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 8309103..4efefa7 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -445,8 +445,8 @@ class LlamaModel_Adapted(LlamaModel):
 			if not self.dropoff_layer( l ):
 				hidden_states = layer_outputs[0]
 
-			if use_cache:
-				next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+				if use_cache:
+					next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
 			if output_attentions:
 				all_self_attns += (layer_outputs[1],)
diff --git a/vall_e/train.py b/vall_e/train.py
index 39cf18d..0f7c7ee 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -196,9 +196,9 @@ def train():
 	# copy config yaml to backup
 	if cfg.yaml_path is not None and is_global_leader():
 		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
-
+	# create dataloaders
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-
+	# evaluation lambda
 	def eval_fn(engines):
 		do_gc()
 		engines.eval()
@@ -213,18 +213,23 @@
 		engines.train()
 		qnt.unload_model()
 		do_gc()
-
+	# unload EnCodec if it's already loaded
 	qnt.unload_model()
-
+	# only eval is requested
 	if args.eval:
 		return eval_fn(engines=trainer.load_engines())
 
 	"""
+	# start web UI
 	if cfg.trainer.load_webui:
 		from .webui import start
 		start(lock=False)
 	"""
+	# pre-training config validation
+	if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
+		_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
+
 	# train
 	trainer.train(
 		train_dl=train_dl,
 		train_feeder=train_feeder,