Also cull frozen_params from the params the optimizer receives, to reduce the VRAM it consumes

This commit is contained in:
mrq 2023-09-07 18:27:02 -05:00
parent 8837bc34d7
commit 67617d7d69

View File

@ -73,7 +73,7 @@ def load_engines(invert=False):
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.AdamW(
model.parameters(),
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
elif cfg.hyperparameters.optimizer.lower() == "sgd":
@ -82,7 +82,7 @@ def load_engines(invert=False):
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.SGD(
model.parameters(),
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
@ -91,7 +91,7 @@ def load_engines(invert=False):
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.Prodigy(
model.parameters(),
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)