diff --git a/codes/scripts/byol/byol_extract_wrapped_model.py b/codes/scripts/byol/byol_extract_wrapped_model.py index 644ade09..e002870c 100644 --- a/codes/scripts/byol/byol_extract_wrapped_model.py +++ b/codes/scripts/byol/byol_extract_wrapped_model.py @@ -2,16 +2,20 @@ import torch from models.spinenet_arch import SpineNet +def extract_byol_model_from_state_dict(sd): + wrap_key = 'online_encoder.net.' + sdo = {} + for k,v in sd.items(): + if wrap_key in k: + sdo[k.replace(wrap_key, '')] = v + return sdo + if __name__ == '__main__': pretrained_path = '../../../experiments/uresnet_pixpro4_imgset.pth' output_path = '../../../experiments/uresnet_pixpro4_imgset.pth' - wrap_key = 'online_encoder.net.' sd = torch.load(pretrained_path) - sdo = {} - for k,v in sd.items(): - if wrap_key in k: - sdo[k.replace(wrap_key, '')] = v + sd = extract_byol_model_from_state_dict(sd) #model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') #model.load_state_dict(sdo, strict=True)