Reduce peak memory usage when changing models

A few tweaks to reduce peak memory usage, the biggest being that if we
aren't using the checkpoint cache, we shouldn't duplicate the model
state dict just to immediately throw it away.

On my machine with 16GB of RAM, this change means I can typically change
models, whereas before it would typically OOM.
This commit is contained in:
Josh Watzman 2022-10-27 21:59:16 +01:00
parent 737eb28fac
commit b50ff4f4e4

View File

@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd) sd = get_state_dict_from_checkpoint(pl_sd)
missing, extra = model.load_state_dict(sd, strict=False) del pl_sd
model.load_state_dict(sd, strict=False)
del sd
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
checkpoints_loaded[checkpoint_info] = model.state_dict().copy() if shared.opts.sd_checkpoint_cache > 0:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
checkpoints_loaded.popitem(last=False) # LRU while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else: else:
print(f"Loading weights [{sd_model_hash}] from cache") print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info) checkpoints_loaded.move_to_end(checkpoint_info)