2020-12-10 16:57:52 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
2021-08-06 04:15:26 +00:00
|
|
|
def extract_byol_model_from_state_dict(sd, wrap_key='online_encoder.net.'):
    """Extract the wrapped network's weights from a BYOL-style state dict.

    Only entries whose key contains ``wrap_key`` are kept, and ``wrap_key`` is
    removed from each kept key.  Removal uses ``str.replace``, so every
    occurrence of ``wrap_key`` in a key is stripped, not only a leading prefix.
    If two input keys map to the same stripped key, the later one wins.

    Args:
        sd: Mapping of parameter name -> value (typically tensors loaded from
            a checkpoint). It is not modified.
        wrap_key: Substring identifying the wrapped sub-network; defaults to
            ``'online_encoder.net.'``.

    Returns:
        A new plain ``dict`` of stripped key -> value. Values are the same
        objects as in ``sd`` (not copied).
    """
    return {
        key.replace(wrap_key, ''): value
        for key, value in sd.items()
        if wrap_key in key
    }
|
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
|
2021-08-06 04:15:26 +00:00
|
|
|
if __name__ == '__main__':
    # NOTE(review): input and output paths are identical, so the checkpoint is
    # rewritten in place. After one run the 'online_encoder.net.' wrapper is
    # gone from every key, so re-running would extract nothing -- the emptiness
    # check below prevents clobbering the file with an empty dict in that case.
    pretrained_path = '../../../experiments/uresnet_pixpro4_imgset.pth'
    output_path = '../../../experiments/uresnet_pixpro4_imgset.pth'

    # Assumes the .pth file holds a flat parameter-name -> tensor mapping;
    # TODO confirm it is not nested under e.g. a 'state_dict' key.
    sd = torch.load(pretrained_path)
    sd = extract_byol_model_from_state_dict(sd)
    if not sd:
        raise SystemExit(
            f"No keys containing 'online_encoder.net.' found in {pretrained_path}; "
            "refusing to overwrite the checkpoint with an empty state dict."
        )

    # Optional strict-load validation (disabled; SpineNet is not imported here):
    # model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
    # model.load_state_dict(sd, strict=True)

    print("Validation succeeded, dumping state dict to output path.")
    # Save the extracted dict held in `sd`; the function's internal name is not
    # in scope here.
    torch.save(sd, output_path)
|