actually, float16 (+AMP) combined with LayerSkip is bad and will kill the model
parent edf1e66bf9
commit fb8faa295b
```diff
@@ -948,15 +948,6 @@ class Config(BaseConfig):
 				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
 				del model["interleave"]
 
-			if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
-				_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
-				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
-				del model["experimental"]["p_rvq_levels"]
-
-			if "audio_embedding_sums" in model:
-				_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
-				model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
-
 		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
 		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
 
```
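The hunk above removes the backwards-compatibility shims that used to rewrite older configs in place: `p_rvq_levels` is no longer renamed to `rvq_levels_p`, and a top-level `audio_embedding_sums` is no longer moved under `experimental`. Configs that still use the old spellings now have to be updated by hand. A rough sketch of that one-off migration (not part of this commit; `migrate_model_config` is a hypothetical helper operating on the raw dict parsed from the YAML, mirroring the deleted shims):

```python
# Hypothetical one-off migration mirroring the shims removed above; apply it to
# each raw `model` dict loaded from an old config.yaml before Model(**model).
def migrate_model_config(model: dict) -> dict:
	experimental = model.setdefault("experimental", {})
	# old spelling of the RVQ level sampling weights
	if "p_rvq_levels" in experimental and "rvq_levels_p" not in experimental:
		experimental["rvq_levels_p"] = experimental.pop("p_rvq_levels")
	# audio_embedding_sums now lives under `experimental`
	if "audio_embedding_sums" in model:
		experimental["audio_embedding_sums"] = model.pop("audio_embedding_sums")
	return model
```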
```diff
@@ -196,9 +196,9 @@ def train():
 	# copy config yaml to backup
 	if cfg.yaml_path is not None and is_global_leader():
 		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
-
+	# create dataloaders
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-
+	# evaluation lambda
 	def eval_fn(engines):
 		do_gc()
 		engines.eval()
```
```diff
@@ -213,18 +213,23 @@ def train():
 		engines.train()
 		qnt.unload_model()
 		do_gc()
-
+	# unload EnCodec if it's already loaded
 	qnt.unload_model()
-
+	# only eval is requested
 	if args.eval:
 		return eval_fn(engines=trainer.load_engines())
 
 	"""
-
+	# start web UI
 	if cfg.trainer.load_webui:
 		from .webui import start
 		start(lock=False)
 	"""
+	# pre-training config validation
+	if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
+		_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
+
+	# train
 	trainer.train(
 		train_dl=train_dl,
 		train_feeder=train_feeder,
```
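The new check only logs a warning; the actual fix is to train LayerSkip models under bfloat16 (presumably set via `trainer.weight_dtype: bfloat16` in the YAML, mirroring the `cfg.trainer.weight_dtype` attribute checked above) instead of float16. A plausible reason the combination "fries" the model is dynamic range: LayerSkip trains with additional early-exit losses across layers, and float16 saturates around 65504, so large intermediate values (activations, per-layer losses, scaled gradients) can overflow to inf and propagate NaNs, while bfloat16 keeps float32's exponent range at reduced precision. A minimal, illustrative sketch of that range difference (not from the repo):

```python
import torch

# Illustrative only: float16 overflows past ~65504, while bfloat16 shares
# float32's exponent range, so the same accumulation stays finite.
per_layer_terms_fp16 = torch.full((32,), 4096.0, dtype=torch.float16)
per_layer_terms_bf16 = torch.full((32,), 4096.0, dtype=torch.bfloat16)

print(per_layer_terms_fp16.sum())  # tensor(inf, dtype=torch.float16): 131072 > 65504
print(per_layer_terms_bf16.sum())  # tensor(131072., dtype=torch.bfloat16)
```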