Misc
This commit is contained in:
parent
6c0e9f45c7
commit
6c27ddc9b5
|
@ -5,17 +5,17 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
netG = define_G(opt)
|
netG = define_G(opt)
|
||||||
dummyInput = torch.rand(1,3,8,8)
|
dummyInput = torch.rand(1,3,8,8)
|
||||||
|
|
||||||
mode = 'torchscript'
|
mode = 'trace'
|
||||||
if mode == 'torchscript':
|
if mode == 'torchscript':
|
||||||
print("Tracing generator network..")
|
print("Tracing generator network..")
|
||||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
traced_netG = torch.jit.trace(netG, dummyInput)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import onnxruntime
|
import onnx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
|
||||||
session = onnxruntime.InferenceSession("../results/gen.onnx")
|
model = onnx.load('../results/gen.onnx')
|
||||||
v = np.random.randn(1,3,1700,1500)
|
|
||||||
st = time.time()
|
outputs = {}
|
||||||
prediction = session.run(None, {"lr_input": v.astype(np.float32)})
|
for n in model.graph.node:
|
||||||
print("Took %f" % (time.time() - st))
|
for o in n.output:
|
||||||
print(prediction[0].shape)
|
outputs[o] = n
|
||||||
|
|
||||||
|
res = 0
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user