diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 3f4e45a..40b806a 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -97,7 +97,8 @@ def load_engines(training=True): optimizer = scheduler_class( [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], - lr = params['lr'] + lr = params['lr'], + warmup_steps = cfg.hyperparameters.warmup_steps )