also cull frozen_params in the params optimizer receives to reduce VRAM it consumes
This commit is contained in:
parent
8837bc34d7
commit
67617d7d69
|
@ -73,7 +73,7 @@ def load_engines(invert=False):
|
|||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.AdamW(
|
||||
model.parameters(),
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
|
@ -82,7 +82,7 @@ def load_engines(invert=False):
|
|||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.SGD(
|
||||
model.parameters(),
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
|
@ -91,7 +91,7 @@ def load_engines(invert=False):
|
|||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.Prodigy(
|
||||
model.parameters(),
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user