From 67617d7d695e9196ad89ce54d335344c8c9f2168 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 7 Sep 2023 18:27:02 -0500 Subject: [PATCH] also cull frozen_params in the params optimizer receives to reduce VRAM it consumes --- vall_e/utils/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index c5b693a..d87df60 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -73,7 +73,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.AdamW( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, ) elif cfg.hyperparameters.optimizer.lower() == "sgd": @@ -82,7 +82,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.SGD( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, ) elif cfg.hyperparameters.optimizer.lower() == "prodigy": @@ -91,7 +91,7 @@ def load_engines(invert=False): } params.update(cfg.hyperparameters.optimizer_params) optimizer = ml.Prodigy( - model.parameters(), + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, )