Reduce peak memory usage when changing models
A few tweaks to reduce peak memory usage. The biggest one: if the checkpoint cache is not in use, we no longer duplicate the model state dict just to throw it away immediately. The intermediate checkpoint dictionaries (`pl_sd` and `sd`) are also deleted as soon as they are no longer needed. On my machine with 16 GB of RAM, this change means I can typically switch models, whereas before it would typically run out of memory (OOM).
This commit is contained in:
parent
737eb28fac
commit
b50ff4f4e4
|
@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
missing, extra = model.load_state_dict(sd, strict=False)
|
del pl_sd
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
del sd
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
|
@ -194,6 +196,7 @@ def load_model_weights(model, checkpoint_info):
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||||
checkpoints_loaded.popitem(last=False) # LRU
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
|
Loading…
Reference in New Issue
Block a user