This commit is contained in:
James Betker 2020-06-14 11:03:02 -06:00
parent 6c0e9f45c7
commit 6c27ddc9b5
3 changed files with 12 additions and 10 deletions

View File

@ -5,17 +5,17 @@ import torch
import torchvision import torchvision
import torch.nn.functional as F import torch.nn.functional as F
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
netG = define_G(opt) netG = define_G(opt)
dummyInput = torch.rand(1,3,8,8) dummyInput = torch.rand(1,3,8,8)
mode = 'torchscript' mode = 'trace'
if mode == 'torchscript': if mode == 'torchscript':
print("Tracing generator network..") print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput) traced_netG = torch.jit.trace(netG, dummyInput)

View File

@ -1,10 +1,12 @@
import onnxruntime import onnx
import numpy as np import numpy as np
import time import time
session = onnxruntime.InferenceSession("../results/gen.onnx") model = onnx.load('../results/gen.onnx')
v = np.random.randn(1,3,1700,1500)
st = time.time() outputs = {}
prediction = session.run(None, {"lr_input": v.astype(np.float32)}) for n in model.graph.node:
print("Took %f" % (time.time() - st)) for o in n.output:
print(prediction[0].shape) outputs[o] = n
res = 0

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)