From 6c27ddc9b5cfb69b3f41d643166df6006968c343 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 14 Jun 2020 11:03:02 -0600 Subject: [PATCH] Misc --- codes/distill_torchscript.py | 4 ++-- codes/onnx_inference.py | 16 +++++++++------- codes/train.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/codes/distill_torchscript.py b/codes/distill_torchscript.py index 336cc82f..b10373cd 100644 --- a/codes/distill_torchscript.py +++ b/codes/distill_torchscript.py @@ -5,17 +5,17 @@ import torch import torchvision import torch.nn.functional as F - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) + netG = define_G(opt) dummyInput = torch.rand(1,3,8,8) - mode = 'torchscript' + mode = 'trace' if mode == 'torchscript': print("Tracing generator network..") traced_netG = torch.jit.trace(netG, dummyInput) diff --git a/codes/onnx_inference.py b/codes/onnx_inference.py index f61a243f..a2f214a4 100644 --- a/codes/onnx_inference.py +++ b/codes/onnx_inference.py @@ -1,10 +1,12 @@ -import onnxruntime +import onnx import numpy as np import time -session = onnxruntime.InferenceSession("../results/gen.onnx") -v = np.random.randn(1,3,1700,1500) -st = time.time() -prediction = session.run(None, {"lr_input": v.astype(np.float32)}) -print("Took %f" % (time.time() - st)) -print(prediction[0].shape) \ No newline at end of file +model = onnx.load('../results/gen.onnx') + +outputs = {} +for n in model.graph.node: + for o in n.output: + outputs[o] = n + +res = 0 \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index 67ad9dac..d0ee97ff 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)