diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index c5b693a..d87df60 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -73,7 +73,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.AdamW( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, ) elif cfg.hyperparameters.optimizer.lower() == "sgd": @@ -82,7 +82,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.SGD( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, ) elif cfg.hyperparameters.optimizer.lower() == "prodigy": @@ -91,7 +91,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.Prodigy( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, )