Allow uresnet to use pretrained resnet50
This commit is contained in:
parent
4119cd6240
commit
f3db381fa1
|
@ -14,11 +14,7 @@ import torchvision
|
|||
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
model_urls = {
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
}
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class ReverseBottleneck(nn.Module):
|
||||
|
@ -143,6 +139,9 @@ class UResNet50(torchvision.models.resnet.ResNet):
|
|||
@register_model
|
||||
def register_u_resnet50(opt_net, opt):
|
||||
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
||||
if opt_get(opt_net, ['use_pretrained_base'], False):
|
||||
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user