Misc
This commit is contained in:
parent
6c0e9f45c7
commit
6c27ddc9b5
|
@ -5,17 +5,17 @@ import torch
|
|||
import torchvision
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
netG = define_G(opt)
|
||||
dummyInput = torch.rand(1,3,8,8)
|
||||
|
||||
mode = 'torchscript'
|
||||
mode = 'trace'
|
||||
if mode == 'torchscript':
|
||||
print("Tracing generator network..")
|
||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import onnxruntime
|
||||
import onnx
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
session = onnxruntime.InferenceSession("../results/gen.onnx")
|
||||
v = np.random.randn(1,3,1700,1500)
|
||||
st = time.time()
|
||||
prediction = session.run(None, {"lr_input": v.astype(np.float32)})
|
||||
print("Took %f" % (time.time() - st))
|
||||
print(prediction[0].shape)
|
||||
model = onnx.load('../results/gen.onnx')
|
||||
|
||||
outputs = {}
|
||||
for n in model.graph.node:
|
||||
for o in n.output:
|
||||
outputs[o] = n
|
||||
|
||||
res = 0
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user