2020-05-13 15:21:13 +00:00
|
|
|
import argparse
|
|
|
|
import options.options as option
|
|
|
|
from models.networks import define_G
|
|
|
|
import torch
|
2020-05-24 03:09:38 +00:00
|
|
|
import torchvision
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
2020-05-13 15:21:13 +00:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line: a single option that points at the YAML options file.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-opt',
        type=str,
        help='Path to options YAML file.',
        default='../options/use_video_upsample.yml',
    )
    cli_args = arg_parser.parse_args()

    # Parse the options with is_train=False, pass them through
    # option.dict_to_nonedict, and build the network from the result.
    opt = option.dict_to_nonedict(option.parse(cli_args.opt, is_train=False))
    netG = define_G(opt)

    # Random 1x3x8x8 tensor used as the example input by every branch below.
    dummyInput = torch.rand(1, 3, 8, 8)

    # Hard-coded selector for the branch to run.
    mode = 'torchscript'

    if mode == 'torchscript':
        # Trace the network with the example input and save the traced module.
        print("Tracing generator network..")
        traced_netG = torch.jit.trace(netG, dummyInput)
        traced_netG.save('../results/ts_generator.zip')

        # Dump the traced code, then every submodule of the RRDB trunk.
        print(traced_netG.code)
        for index, submodule in enumerate(traced_netG.RRDB_trunk.modules()):
            print(index, str(submodule))
    elif mode == 'onnx':
        print("Performing onnx trace")
        input_names = ["lr_input"]
        output_names = ["hr_image"]
        # All four axes of both the input and the output are declared dynamic,
        # each with the same set of labels.
        axis_labels = {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}
        dynamic_axes = {
            tensor_name: dict(axis_labels)
            for tensor_name in input_names + output_names
        }
        torch.onnx.export(
            netG,
            dummyInput,
            "../results/gen.onnx",
            verbose=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=11,
        )
    elif mode == 'trace':
        # Only the first element of the forward result is inspected.
        out = netG.forward(dummyInput)[0]
        print(out.shape)
        # Build the graph backwards.
        # NOTE(review): `build_graph` is neither defined nor imported in the
        # visible part of this file, so this branch presumably raises
        # NameError as written - confirm the intended entry point.
        graph = build_graph(out, 'output')
|
|
|
|
|
|
|
|
def get_unique_id_for_fn(fn):
    """Return an identifier string for ``fn`` taken from its string form.

    When ``str(fn)`` contains ``" object at "`` (as the default
    ``<Type object at 0x...>`` repr does), the identifier is the text that
    follows the first occurrence of that marker, up to the next occurrence
    if any, with its last character (normally the closing ``>``) removed.

    When the marker is absent (``None``, or an object with a custom
    ``__repr__``/``__str__``), ``hex(id(fn))`` is returned instead of
    raising ``IndexError``.

    Args:
        fn: Any object; in this file typically an autograd function or a
            ``GraphNode``.

    Returns:
        str: The identifier, intended for use as a dictionary key.
    """
    pieces = str(fn).split(" object at ")
    if len(pieces) < 2:
        # No address in the textual form: fall back to the object's identity.
        return hex(id(fn))
    return pieces[1][:-1]
|
|
|
|
|
|
|
|
class GraphNode:
    """One vertex of the graph: a wrapped function plus its linked neighbours.

    Attributes are plain instance fields:
      name     - text of ``str(fn)`` before ``" object at "``, minus its
                 first character (the ``<`` of a default repr).
      fn       - the wrapped object itself.
      children - dict of linked objects registered through ``add_child``.
      parents  - dict of linked objects registered through ``add_parent``.
    Both dicts are keyed by ``get_unique_id_for_fn`` of the stored object.
    """

    def __init__(self, fn):
        """Wrap ``fn`` and start with no parent or child links."""
        label, _, _ = str(fn).partition(" object at ")
        self.name = label[1:]
        self.fn = fn
        self.children = {}
        self.parents = {}

    def add_parent(self, parent):
        """Store ``parent`` in ``self.parents`` under its unique id."""
        key = get_unique_id_for_fn(parent)
        self.parents[key] = parent

    def add_child(self, child):
        """Store ``child`` in ``self.children`` under its unique id."""
        key = get_unique_id_for_fn(child)
        self.children[key] = child
|
|
|
|
|
|
|
|
class TorchGraph:
    """Graph of functions discovered by walking backwards from an output tensor.

    ``tensor_map`` maps the id string returned by ``get_unique_id_for_fn``
    for every visited function to the ``GraphNode`` wrapping it.
    ``inputs`` is filled by ``init`` with the nodes that have no parents.
    """

    def __init__(self):
        self.tensor_map = {}
        # Populated by init(); defined here so the attribute always exists.
        self.inputs = []

    def get_node_for_tensor(self, t):
        """Return the ``GraphNode`` registered for ``t``.

        ``tensor_map`` is keyed by ids of the function objects visited in
        ``build_graph_backwards``, so when ``t`` exposes a non-None
        ``grad_fn`` that function is used for the lookup; otherwise ``t``
        itself is used (which keeps working for callers that already pass
        a function object).

        Raises:
            KeyError: if no node was registered under the resulting id.
        """
        fn = getattr(t, "grad_fn", None)
        if fn is None:
            fn = t
        return self.tensor_map[get_unique_id_for_fn(fn)]

    def init(self, output_tensor):
        """Build the graph from ``output_tensor.grad_fn`` and collect inputs.

        Raises:
            ValueError: if ``output_tensor.grad_fn`` is None, i.e. there is
                nothing to walk backwards from.
        """
        root_fn = output_tensor.grad_fn
        if root_fn is None:
            raise ValueError(
                "output_tensor has no grad_fn; cannot build a graph from it")
        self.build_graph_backwards(root_fn, None)
        # A node is an input exactly when its parents dict is empty.
        self.inputs = [
            node for node in self.tensor_map.values() if not node.parents
        ]

    def build_graph_backwards(self, fn, previous_fn):
        """Register ``fn`` and, on first visit, everything reachable from it.

        Args:
            fn: Function object exposing ``next_functions``.
            previous_fn: The function whose ``next_functions`` led here, or
                None for the starting function. It is recorded as a child
                of the node for ``fn`` on every visit.

        Returns:
            GraphNode: the (new or previously created) node for ``fn``.

        Note:
            The walk is recursive, so its depth follows the depth of the
            graph being traversed.
        """
        node_id = get_unique_id_for_fn(fn)
        node = self.tensor_map.get(node_id)
        first_visit = node is None
        if first_visit:
            node = GraphNode(fn)
            # Register before recursing so shared nodes are expanded once.
            self.tensor_map[node_id] = node

        if previous_fn is not None:
            node.add_child(previous_fn)

        if first_visit:
            for entry in fn.next_functions:
                # Autograd exposes next_functions as (function, index)
                # pairs; a bare function is accepted as well.
                next_fn = entry[0] if isinstance(entry, tuple) else entry
                if next_fn is None:
                    # No function recorded for this slot: nothing to follow.
                    continue
                node.add_parent(self.build_graph_backwards(next_fn, fn))
        return node
|