forked from mrq/DL-Art-School
Export extract_byol_model as a function
This commit is contained in:
parent
89d15c9e74
commit
f86df53ce0
|
@ -2,16 +2,20 @@ import torch
|
|||
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
def extract_byol_model_from_state_dict(sd):
|
||||
wrap_key = 'online_encoder.net.'
|
||||
sdo = {}
|
||||
for k,v in sd.items():
|
||||
if wrap_key in k:
|
||||
sdo[k.replace(wrap_key, '')] = v
|
||||
return sdo
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../../../experiments/uresnet_pixpro4_imgset.pth'
|
||||
output_path = '../../../experiments/uresnet_pixpro4_imgset.pth'
|
||||
|
||||
wrap_key = 'online_encoder.net.'
|
||||
sd = torch.load(pretrained_path)
|
||||
sdo = {}
|
||||
for k,v in sd.items():
|
||||
if wrap_key in k:
|
||||
sdo[k.replace(wrap_key, '')] = v
|
||||
sd = extract_byol_model_from_state_dict(sd)
|
||||
|
||||
#model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
||||
#model.load_state_dict(sdo, strict=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user