import os.path as osp
import sys
import torch
try:
    sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
    import models.archs.SRResNet_arch as SRResNet_arch
except ImportError:
    pass

pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()

for k, v in crt_net.items():
    if k in pretrained_net and 'upconv1' not in k:
        crt_net[k] = pretrained_net[k]
        print('replace ... ', k)

# x4 -> x3
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2

torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')