make main model loading and model merger use the same code
This commit is contained in:
parent
050a6a798c
commit
c77c89cc83
|
@ -169,9 +169,9 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
||||||
|
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
theta_0 = primary_model['state_dict']
|
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||||
theta_1 = secondary_model['state_dict']
|
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||||
|
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted Sum": weighted_sum,
|
"Weighted Sum": weighted_sum,
|
||||||
|
|
|
@ -122,6 +122,13 @@ def select_checkpoint():
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_dict_from_checkpoint(pl_sd):
|
||||||
|
if "state_dict" in pl_sd:
|
||||||
|
return pl_sd["state_dict"]
|
||||||
|
|
||||||
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info):
|
def load_model_weights(model, checkpoint_info):
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
sd_model_hash = checkpoint_info.hash
|
sd_model_hash = checkpoint_info.hash
|
||||||
|
@ -131,11 +138,8 @@ def load_model_weights(model, checkpoint_info):
|
||||||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
if "state_dict" in pl_sd:
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
else:
|
|
||||||
sd = pl_sd
|
|
||||||
|
|
||||||
model.load_state_dict(sd, strict=False)
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user