Allow uresnet to use pretrained resnet50

This commit is contained in:
James Betker 2021-01-10 12:57:31 -07:00
parent 4119cd6240
commit f3db381fa1

View File

@ -14,11 +14,7 @@ import torchvision
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint, opt_get
model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
class ReverseBottleneck(nn.Module): class ReverseBottleneck(nn.Module):
@ -143,6 +139,9 @@ class UResNet50(torchvision.models.resnet.ResNet):
@register_model @register_model
def register_u_resnet50(opt_net, opt): def register_u_resnet50(opt_net, opt):
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
if opt_get(opt_net, ['use_pretrained_base'], False):
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
model.load_state_dict(state_dict, strict=False)
return model return model