DL-Art-School/codes/scripts/transfer_params_MSRResNet.py
XintaoWang 037933ba66 mmsr
2019-08-23 21:42:47 +08:00

28 lines
1.1 KiB
Python

# Build an initial x3 MSRResNet checkpoint from the pretrained x4 one.
#
# Every parameter whose name exists in both networks is copied verbatim,
# except `upconv1`, whose output-channel count differs between the two
# models and is therefore filled by tiling the pretrained tensor.
import os.path as osp
import sys

import torch

# Make the directory two levels above this file importable so the
# project-local `models` package resolves independently of the CWD.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
try:
    import models.archs.SRResNet_arch as SRResNet_arch
except ImportError as err:
    # Failing here with the real cause is clearer than continuing and
    # hitting a NameError on `SRResNet_arch` a few lines further down.
    raise ImportError(
        'Could not import models.archs.SRResNet_arch; this script must live '
        'one directory below the package that contains `models`.') from err

# Load onto the CPU so the conversion also works on machines without CUDA
# and the saved checkpoint holds tensors on a single device.
pretrained_net = torch.load(
    '../../experiments/pretrained_models/MSRResNetx4.pth', map_location='cpu')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()

# Copy every shared parameter except upconv1 (handled separately below).
# Only existing keys are re-bound, so mutating the dict while iterating
# over it is safe.
for k in crt_net:
    if k in pretrained_net and 'upconv1' not in k:
        crt_net[k] = pretrained_net[k]
        print('replace ... ', k)

# x4 -> x3
# The slices below require the pretrained upconv1 to have 256 output channels
# and the x3 one to have at least 576. The x3 tensor is filled with two full
# copies of the pretrained tensor (2 * 256 = 512) plus its first 64 channels
# (512 + 64 = 576), each scaled by 1/2.
# NOTE(review): the rationale for the 1/2 scale is not documented here, and
# 576 is presumably 64 * 3**2 for a x3 pixel shuffle - confirm against
# SRResNet_arch.
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2

torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')