Script to extract models from a wrapped BYOL model
This commit is contained in:
parent
a5630d282f
commit
9c5e272a22
20
codes/scripts/byol_extract_wrapped_model.py
Normal file
20
codes/scripts/byol_extract_wrapped_model.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import torch
|
||||
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../../experiments/train_byol_512unsupervised/models/117000_generator.pth'
|
||||
output_path = '../../experiments/spinenet49_imgset_byol.pth'
|
||||
|
||||
wrap_key = 'online_encoder.net.'
|
||||
sd = torch.load(pretrained_path)
|
||||
sdo = {}
|
||||
for k,v in sd.items():
|
||||
if wrap_key in k:
|
||||
sdo[k.replace(wrap_key, '')] = v
|
||||
|
||||
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
||||
model.load_state_dict(sdo, strict=True)
|
||||
|
||||
print("Validation succeeded, dumping state dict to output path.")
|
||||
torch.save(sdo, output_path)
|
Loading…
Reference in New Issue
Block a user