actually float16(+AMP) and layerskip is bad and will kill the model......
This commit is contained in:
parent
edf1e66bf9
commit
fb8faa295b
|
@ -948,15 +948,6 @@ class Config(BaseConfig):
|
|||
_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
|
||||
del model["interleave"]
|
||||
|
||||
if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
|
||||
_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
|
||||
model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
|
||||
del model["experimental"]["p_rvq_levels"]
|
||||
|
||||
if "audio_embedding_sums" in model:
|
||||
_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
|
||||
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
|
||||
|
||||
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
|
||||
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
|
||||
|
||||
|
|
|
@ -445,8 +445,8 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
if not self.dropoff_layer( l ):
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
|
|
@ -196,9 +196,9 @@ def train():
|
|||
# copy config yaml to backup
|
||||
if cfg.yaml_path is not None and is_global_leader():
|
||||
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
|
||||
|
||||
# create dataloaders
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
# evaluation lambda
|
||||
def eval_fn(engines):
|
||||
do_gc()
|
||||
engines.eval()
|
||||
|
@ -213,18 +213,23 @@ def train():
|
|||
engines.train()
|
||||
qnt.unload_model()
|
||||
do_gc()
|
||||
|
||||
# unload EnCodec if it's already loaded
|
||||
qnt.unload_model()
|
||||
|
||||
# only eval is requested
|
||||
if args.eval:
|
||||
return eval_fn(engines=trainer.load_engines())
|
||||
|
||||
"""
|
||||
# start web UI
|
||||
if cfg.trainer.load_webui:
|
||||
from .webui import start
|
||||
start(lock=False)
|
||||
"""
|
||||
# pre-training config validation
|
||||
if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
|
||||
_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
|
||||
|
||||
# train
|
||||
trainer.train(
|
||||
train_dl=train_dl,
|
||||
train_feeder=train_feeder,
|
||||
|
|
Loading…
Reference in New Issue
Block a user