15 lines
415 B
Python
15 lines
415 B
Python
import munch
|
|
import torch
|
|
|
|
from trainer.networks import register_model
|
|
|
|
|
|
@register_model
def register_flownet2(opt_net):
    """Build a FlowNet2 network from a network-options mapping.

    Args:
        opt_net: Network options mapping. Only the optional ``'load_path'`` key
            is read here. When it is present, it names a checkpoint file whose
            ``'state_dict'`` entry is loaded into the freshly built network.

    Returns:
        The constructed ``FlowNet2`` module, with weights loaded when
        ``'load_path'`` was supplied.
    """
    # Deferred import: FlowNet2 is only imported when this factory is called
    # (presumably to keep registration cheap / avoid import cycles - confirm).
    from models.flownet2.models import FlowNet2

    has_pretrained = 'load_path' in opt_net
    # NOTE(review): 'checkpoint' is enabled only when no pretrained weights are
    # given; its exact meaning is defined inside FlowNet2 - confirm there.
    args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not has_pretrained})
    netG = FlowNet2(args)

    if has_pretrained:
        # Load tensors onto the CPU: load_state_dict copies them into the
        # module's own parameters regardless, and this avoids failing on hosts
        # that lack the GPU the checkpoint was saved from.
        sd = torch.load(opt_net['load_path'], map_location='cpu')
        netG.load_state_dict(sd['state_dict'])

    # Hand the network back so callers of the registered factory receive it.
    return netG