From 6400607fc5349156adcd49bf1a8163939258ef6d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 19 May 2020 09:36:04 -0600 Subject: [PATCH] ONNX export support --- codes/distill_torchscript.py | 22 ++++++++++++++++------ codes/onnx_inference.py | 10 ++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 codes/onnx_inference.py diff --git a/codes/distill_torchscript.py b/codes/distill_torchscript.py index 07f2ecd8..8a4ea1cc 100644 --- a/codes/distill_torchscript.py +++ b/codes/distill_torchscript.py @@ -5,13 +5,23 @@ import torch if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_vrp_upsample.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) netG = define_G(opt) + dummyInput = torch.rand(1,3,8,8) - print("Tracing generator network..") - dummyInput = torch.rand(1, 3, 8, 8) - traced_netG = torch.jit.trace(netG, dummyInput) - traced_netG.save('../results/traced_generator.zip') - print(traced_netG) \ No newline at end of file + torchscript = False + if torchscript: + print("Tracing generator network..") + traced_netG = torch.jit.trace(netG, dummyInput) + traced_netG.save('../results/ts_generator.zip') + print(traced_netG) + else: + print("Performing onnx trace") + input_names = ["lr_input"] + output_names = ["hr_image"] + dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}} + + torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names, + output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11) \ No newline at end of file diff --git a/codes/onnx_inference.py b/codes/onnx_inference.py new file mode 100644 index 00000000..f61a243f --- /dev/null +++ b/codes/onnx_inference.py @@ -0,0 +1,10 @@ +import onnxruntime +import numpy as np +import time + +session = onnxruntime.InferenceSession("../results/gen.onnx") +v = np.random.randn(1,3,1700,1500) +st = time.time() +prediction = session.run(None, {"lr_input": v.astype(np.float32)}) +print("Took %f" % (time.time() - st)) +print(prediction[0].shape) \ No newline at end of file