Export extract_byol_model as a function

This commit is contained in:
James Betker 2021-08-05 22:15:26 -06:00
parent 89d15c9e74
commit f86df53ce0

View File

@ -2,16 +2,20 @@ import torch
from models.spinenet_arch import SpineNet from models.spinenet_arch import SpineNet
def extract_byol_model_from_state_dict(sd):
wrap_key = 'online_encoder.net.'
sdo = {}
for k,v in sd.items():
if wrap_key in k:
sdo[k.replace(wrap_key, '')] = v
return sdo
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../../../experiments/uresnet_pixpro4_imgset.pth' pretrained_path = '../../../experiments/uresnet_pixpro4_imgset.pth'
output_path = '../../../experiments/uresnet_pixpro4_imgset.pth' output_path = '../../../experiments/uresnet_pixpro4_imgset.pth'
wrap_key = 'online_encoder.net.'
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
sdo = {} sd = extract_byol_model_from_state_dict(sd)
for k,v in sd.items():
if wrap_key in k:
sdo[k.replace(wrap_key, '')] = v
#model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') #model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
#model.load_state_dict(sdo, strict=True) #model.load_state_dict(sdo, strict=True)