actually float16(+AMP) and layerskip is bad and will kill the model......

This commit is contained in:
mrq 2024-11-01 18:36:44 -05:00
parent edf1e66bf9
commit fb8faa295b
3 changed files with 11 additions and 15 deletions

View File

@ -948,15 +948,6 @@ class Config(BaseConfig):
_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
del model["interleave"]
if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
del model["experimental"]["p_rvq_levels"]
if "audio_embedding_sums" in model:
_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]

View File

@ -445,8 +445,8 @@ class LlamaModel_Adapted(LlamaModel):
if not self.dropoff_layer( l ):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)

View File

@ -196,9 +196,9 @@ def train():
# copy config yaml to backup
if cfg.yaml_path is not None and is_global_leader():
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
# create dataloaders
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
# evaluation lambda
def eval_fn(engines):
do_gc()
engines.eval()
@ -213,18 +213,23 @@ def train():
engines.train()
qnt.unload_model()
do_gc()
# unload EnCodec if it's already loaded
qnt.unload_model()
# only eval is requested
if args.eval:
return eval_fn(engines=trainer.load_engines())
"""
# start web UI
if cfg.trainer.load_webui:
from .webui import start
start(lock=False)
"""
# pre-training config validation
if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
# train
trainer.train(
train_dl=train_dl,
train_feeder=train_feeder,