Allow uresnet to use pretrained resnet50
This commit is contained in:
parent
4119cd6240
commit
f3db381fa1
|
@ -14,11 +14,7 @@ import torchvision
|
||||||
|
|
||||||
|
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint, opt_get
|
||||||
|
|
||||||
model_urls = {
|
|
||||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ReverseBottleneck(nn.Module):
|
class ReverseBottleneck(nn.Module):
|
||||||
|
@ -143,6 +139,9 @@ class UResNet50(torchvision.models.resnet.ResNet):
|
||||||
@register_model
|
@register_model
|
||||||
def register_u_resnet50(opt_net, opt):
|
def register_u_resnet50(opt_net, opt):
|
||||||
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
||||||
|
if opt_get(opt_net, ['use_pretrained_base'], False):
|
||||||
|
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user