28 lines
1.1 KiB
Python
28 lines
1.1 KiB
Python
|
# Build an initial x3 MSRResNet state dict from a pretrained x4 MSRResNet
# checkpoint and save it next to the source checkpoint.
#
# Every parameter that exists under the same name in both networks is copied
# verbatim, except the ``upconv1`` layer: its output-channel count differs
# between the two checkpoints, so the x4 parameters are tiled along dim 0 to
# the x3 size and halved (see ``expand_upconv_param``).
import os.path as osp
import sys

import torch

# NOTE: both paths are relative to the current working directory, exactly as in
# the original script.
PRETRAINED_X4_PATH = '../../experiments/pretrained_models/MSRResNetx4.pth'
INIT_X3_PATH = '../../experiments/pretrained_models/MSRResNetx3_ini.pth'

# Parameters whose leading dimension differs between the x4 and x3 networks.
UPCONV1_PARAMS = ('upconv1.weight', 'upconv1.bias')


def expand_upconv_param(src, out_channels):
    """Tile ``src`` cyclically along dim 0 to ``out_channels`` rows, halved.

    Row ``i`` of the result is ``src[i % src.shape[0]] / 2``. For a 256-row
    source and ``out_channels=576`` this is rows 0-255, rows 0-255 again and
    then rows 0-63, each divided by two.

    Args:
        src: Source tensor (weight or bias) taken from the x4 checkpoint.
        out_channels: Size of dim 0 expected by the target network.

    Returns:
        A new tensor with ``out_channels`` rows; ``src`` is not modified.

    Raises:
        ValueError: If ``src`` has no rows to tile from.
    """
    n_src = src.shape[0]
    if n_src == 0:
        raise ValueError('cannot tile an empty source parameter')
    rows = torch.arange(out_channels, device=src.device) % n_src
    # The halving is kept from the original recipe; its rationale is not
    # documented in this script.
    return src[rows] / 2


def transfer_x4_to_x3(pretrained_net, crt_net):
    """Fill the x3 state dict ``crt_net`` from the x4 state dict.

    Args:
        pretrained_net: Mapping of parameter name -> tensor for the x4 model.
        crt_net: Mapping of parameter name -> tensor for the x3 model. It is
            updated in place: shared entries are rebound to the pretrained
            tensors, and the ``upconv1`` tensors are overwritten in place.

    Returns:
        ``crt_net``, for convenience.
    """
    # Iterate over a snapshot of the keys so that rebinding values can never
    # interfere with the iteration.
    for key in list(crt_net):
        if 'upconv1' in key or key not in pretrained_net:
            continue
        crt_net[key] = pretrained_net[key]
        print('replace ... ', key)

    # x4 -> x3: write into the existing tensors so that their shape (and any
    # storage shared with the model) is preserved.
    for key in UPCONV1_PARAMS:
        target = crt_net[key]
        target.copy_(expand_upconv_param(pretrained_net[key], target.shape[0]))
    return crt_net


def main():
    """Load the x4 checkpoint, convert it and save the x3 initialisation."""
    # Make the parent of this script's directory importable so that the
    # ``models`` package can be found.
    sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
    try:
        import models.archs.SRResNet_arch as SRResNet_arch
    except ImportError as err:
        # Previously swallowed, which only postponed the failure to a
        # confusing NameError when the architecture was first used.
        raise ImportError(
            'could not import models.archs.SRResNet_arch; run this script '
            'from inside the project tree') from err

    # Load onto the CPU so the conversion also works on machines without the
    # device the checkpoint was saved from.
    pretrained_net = torch.load(PRETRAINED_X4_PATH, map_location='cpu')
    crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
    crt_net = transfer_x4_to_x3(pretrained_net, crt_model.state_dict())
    torch.save(crt_net, INIT_X3_PATH)


if __name__ == '__main__':
    main()
|