From f3db381fa1a486ea9027e9c4a32e34f5edf69851 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 10 Jan 2021 12:57:31 -0700 Subject: [PATCH] Allow uresnet to use pretrained resnet50 --- .../pixel_level_contrastive_learning/resnet_unet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet.py b/codes/models/pixel_level_contrastive_learning/resnet_unet.py index 38d0227c..46bc747f 100644 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet.py +++ b/codes/models/pixel_level_contrastive_learning/resnet_unet.py @@ -14,11 +14,7 @@ import torchvision from trainer.networks import register_model -from utils.util import checkpoint - -model_urls = { - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', -} +from utils.util import checkpoint, opt_get class ReverseBottleneck(nn.Module): @@ -143,6 +139,9 @@ class UResNet50(torchvision.models.resnet.ResNet): @register_model def register_u_resnet50(opt_net, opt): model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) + if opt_get(opt_net, ['use_pretrained_base'], False): + state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) + model.load_state_dict(state_dict, strict=False) return model