Removed a lot of legacy stuff I have no intention of using again. The plan is to shape this repo into something more extensible (get it? hah!)
177 lines
6.6 KiB
import argparse
import functools
import torch
from utils import options as option
from models.networks import define_G
class TracedModule:
    """Record kept for a single traced module.

    Args:
        idname: Identifier string for the module this record describes.
    """

    def __init__(self, idname):
        self.idname = idname
        # Both lists start empty; each instance gets its own list objects.
        self.traced_outputs = []
        self.traced_inputs = []
class TorchCustomTrace:
def __init__(self):
self.module_name_counter = {}
self.modules = {}
self.graph = {}
self.module_map_by_inputs = {}
self.module_map_by_outputs = {}
self.inputs_to_func_output_tuple = {}
def add_tracked_module(self, mod: torch.nn.Module):
modname = type(mod).__name__
if modname not in self.module_name_counter.keys():
self.module_name_counter[modname] = 0
self.module_name_counter[modname] += 1
idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
self.modules[idname] = TracedModule(idname)
return idname
# Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
# the backwards pass.
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
mod = self.modules[mod_id]
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[i.data_ptr()] = [mod]
for o in outputs:
if o.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[o.data_ptr()] = [mod]
def mem_backward_hook(self, inputs, outputs, op):
if len(inputs) == 0:
print("No inputs.. %s" % (op,))
outs = [o.data_ptr() for o in outputs]
tup = (outs, op)
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
def install_hooks(self, mod: torch.nn.Module, trace=""):
mod_id = self.add_tracked_module(mod)
my_trace = trace + "->" + mod_id
# If this module has parameters, it also has a state worth tracking.
#if next(mod.parameters(recurse=False), None) is not None:
mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
for m in mod.children():
self.install_hooks(m, my_trace)
def install_backward_hooks(self, grad_fn):
# AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
# tracing the graph.
if grad_fn is None or "AccumulateGrad" in str(grad_fn):
grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
for g, _ in grad_fn.next_functions:
if __name__ == "__main__":
    # Script entry point: build the generator described by the options file,
    # then run one of several tracing/export experiments on a random input.
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    netG = define_G(opt)
    dummyInput = torch.rand(1,3,32,32)

    # The experiment is selected by editing this constant.
    mode = 'onnx'
    if mode == 'torchscript':
        print("Tracing generator network..")
        traced_netG = torch.jit.trace(netG, dummyInput)
        for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
            print(i, str(module))
    elif mode == 'onnx':
        print("Performing onnx trace")
        input_names = ["lr_input"]
        output_names = ["hr_image"]
        # All four axes of both tensors are declared dynamic.
        dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}

        torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
                          output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
    elif mode == 'memtrace':
        criterion = torch.nn.MSELoss()
        tracer = TorchCustomTrace()
        # Forward hooks must be in place before the forward pass runs.
        tracer.install_hooks(netG)
        out, = netG(dummyInput)
        # Backward hooks go on the autograd graph produced by that pass.
        tracer.install_backward_hooks(out.grad_fn)
        target = torch.zeros_like(out)
        loss = criterion(out, target)
        # Running backward is what makes the autograd hooks fire.
        loss.backward()
    elif mode == 'trace':
        out = netG.forward(dummyInput)[0]
        # Build the graph backwards.
        # NOTE(review): build_graph is not defined in the visible part of this
        # file - confirm it exists before enabling this mode.
        graph = build_graph(out, 'output')
def get_unique_id_for_fn(fn):
    """Return a string that identifies the object *fn*.

    For objects whose ``str()`` looks like ``<Name object at 0xADDR>`` this is
    the text between ``" object at "`` and the final character (the address).
    Objects with any other string form fall back to ``hex(id(fn))`` instead of
    raising ``IndexError``.
    """
    pieces = str(fn).split(" object at ")
    if len(pieces) < 2:
        return hex(id(fn))
    # Drop the trailing character, which is the closing ">" of the repr.
    return pieces[1][:-1]
class GraphNode:
    """One node of the graph, wrapping a single function object.

    Args:
        fn: The wrapped object. Its ``str()`` is expected to look like
            ``<Name object at 0xADDR>``.
    """

    def __init__(self, fn):
        # Text before " object at " with the leading "<" removed.
        self.name = (str(fn).split(" object at ")[0])[1:]
        self.fn = fn
        # Both dicts are keyed by get_unique_id_for_fn() of the stored value.
        self.children = {}
        self.parents = {}

    def add_parent(self, parent):
        """Store *parent* in ``self.parents`` under its unique id."""
        self.parents[get_unique_id_for_fn(parent)] = parent

    def add_child(self, child):
        """Store *child* in ``self.children`` under its unique id."""
        self.children[get_unique_id_for_fn(child)] = child
class TorchGraph:
    """Graph of ``GraphNode`` objects built by walking autograd ``next_functions``.

    ``tensor_map`` maps ``get_unique_id_for_fn(fn)`` to the node wrapping ``fn``.
    ``inputs`` lists the nodes that ended up with no parents after ``init``.
    """

    def __init__(self):
        self.tensor_map = {}
        self.inputs = []

    def get_node_for_tensor(self, t):
        """Return the node registered for *t*.

        *t* may be a tensor, in which case its ``grad_fn`` is looked up, or an
        autograd function object itself.

        Raises:
            KeyError: If no node was registered for it.
        """
        # NOTE(review): relies on t.grad_fn yielding the same Python object
        # that was seen while building the graph - confirm.
        fn = t.grad_fn if isinstance(t, torch.Tensor) else t
        return self.tensor_map[get_unique_id_for_fn(fn)]

    def init(self, output_tensor):
        """Build the graph reachable from ``output_tensor.grad_fn``.

        A tensor without a ``grad_fn`` leaves the graph untouched and
        ``inputs`` empty.
        """
        self.inputs = []
        if output_tensor.grad_fn is None:
            return
        self.build_graph_backwards(output_tensor.grad_fn, None)
        # Find inputs
        # Is an input if the parents dict is empty.
        self.inputs = [node for node in self.tensor_map.values() if not node.parents]

    def build_graph_backwards(self, fn, previous_fn):
        """Return the node for *fn*, creating it and its upstream nodes if new.

        Args:
            fn: Autograd function object to wrap.
            previous_fn: The function whose ``next_functions`` led here, or
                ``None`` for the root. When its node is already registered it
                is recorded as a child of the returned node.
        """
        fn_id = get_unique_id_for_fn(fn)
        node = self.tensor_map.get(fn_id)
        if node is None:
            node = GraphNode(fn)
            # Register before recursing so shared functions are expanded once.
            self.tensor_map[fn_id] = node
            # Walk the functions this one points at via next_functions and
            # record their nodes as parents.
            for entry in fn.next_functions:
                # Entries are (function, index) pairs; function may be None.
                next_fn = entry[0] if isinstance(entry, tuple) else entry
                if next_fn is None:
                    continue
                node.add_parent(self.build_graph_backwards(next_fn, fn))
        if previous_fn is not None:
            consumer = self.tensor_map.get(get_unique_id_for_fn(previous_fn))
            if consumer is not None:
                node.add_child(consumer)
        return node