actually float16(+AMP) and layerskip is bad and will kill the model......

This commit is contained in:
mrq 2024-11-01 18:36:44 -05:00
parent edf1e66bf9
commit fb8faa295b
3 changed files with 11 additions and 15 deletions
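For context on the headline claim: the change below only logs a warning, but the intent is that LayerSkip training should run in bfloat16 (or full float32) rather than float16/AMP. A minimal sketch of a stricter pre-flight check is shown here; it is not part of this commit. The attributes cfg.model.experimental.layerskip and cfg.trainer.weight_dtype are the ones checked by the warning added in the diff, while the function name validate_precision is hypothetical.

# Sketch only, not part of this commit: abort instead of just warning.
# cfg.model.experimental.layerskip and cfg.trainer.weight_dtype mirror the check
# added in the diff below; validate_precision is an assumed helper name.
def validate_precision(cfg):
	if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
		raise RuntimeError("LayerSkip with float16 (and float16 AMP) will fry the model; use bfloat16 instead.")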

View File

@@ -948,15 +948,6 @@ class Config(BaseConfig):
				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
				del model["interleave"]
			if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
				_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
				del model["experimental"]["p_rvq_levels"]
			if "audio_embedding_sums" in model:
				_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
				model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]

View File

@@ -196,9 +196,9 @@ def train():
	# copy config yaml to backup
	if cfg.yaml_path is not None and is_global_leader():
		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
	# create dataloaders
	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
	# evaluation lambda
	def eval_fn(engines):
		do_gc()
		engines.eval()
@@ -213,18 +213,23 @@ def train():
		engines.train()
		qnt.unload_model()
		do_gc()
	# unload EnCodec if it's already loaded
	qnt.unload_model()
	# only eval is requested
	if args.eval:
		return eval_fn(engines=trainer.load_engines())
	"""
	# start web UI
	if cfg.trainer.load_webui:
		from .webui import start
		start(lock=False)
	"""
	# pre-training config validation
	if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
		_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
	# train
	trainer.train(
		train_dl=train_dl,
		train_feeder=train_feeder,