This commit is contained in:
James Betker 2020-06-14 11:03:02 -06:00
parent 6c0e9f45c7
commit 6c27ddc9b5
3 changed files with 12 additions and 10 deletions

View File

@ -5,17 +5,17 @@ import torch
import torchvision
import torch.nn.functional as F
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
netG = define_G(opt)
dummyInput = torch.rand(1,3,8,8)
mode = 'torchscript'
mode = 'trace'
if mode == 'torchscript':
print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput)

View File

@ -1,10 +1,12 @@
import onnxruntime
import onnx
import numpy as np
import time
session = onnxruntime.InferenceSession("../results/gen.onnx")
v = np.random.randn(1,3,1700,1500)
st = time.time()
prediction = session.run(None, {"lr_input": v.astype(np.float32)})
print("Took %f" % (time.time() - st))
print(prediction[0].shape)
model = onnx.load('../results/gen.onnx')
outputs = {}
for n in model.graph.node:
for o in n.output:
outputs[o] = n
res = 0

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)