forked from mrq/DL-Art-School
ONNX export support
This commit is contained in:
parent
89c71293ce
commit
6400607fc5
|
@ -5,13 +5,23 @@ import torch
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_vrp_upsample.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
netG = define_G(opt)
|
||||
|
||||
print("Tracing generator network..")
|
||||
dummyInput = torch.rand(1,3,8,8)
|
||||
|
||||
torchscript = False
|
||||
if torchscript:
|
||||
print("Tracing generator network..")
|
||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
||||
traced_netG.save('../results/traced_generator.zip')
|
||||
traced_netG.save('../results/ts_generator.zip')
|
||||
print(traced_netG)
|
||||
else:
|
||||
print("Performing onnx trace")
|
||||
input_names = ["lr_input"]
|
||||
output_names = ["hr_image"]
|
||||
dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
|
||||
|
||||
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
|
||||
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
|
10
codes/onnx_inference.py
Normal file
10
codes/onnx_inference.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import onnxruntime
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
session = onnxruntime.InferenceSession("../results/gen.onnx")
|
||||
v = np.random.randn(1,3,1700,1500)
|
||||
st = time.time()
|
||||
prediction = session.run(None, {"lr_input": v.astype(np.float32)})
|
||||
print("Took %f" % (time.time() - st))
|
||||
print(prediction[0].shape)
|
Loading…
Reference in New Issue
Block a user