Distill_torchscript mods
Starts down the path of writing a custom tracer that works using torch's hook mechanism.
parent db08dedfe2
commit 6f2bc36c61
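The tracer introduced below leans on two PyTorch facilities: nn.Module.register_forward_hook, which fires after each module's forward() call, and register_hook on autograd graph nodes (reached through a tensor's grad_fn), which fires as the corresponding op executes during backward(). A minimal sketch of that mechanism, using an assumed toy nn.Sequential model rather than the repo's define_G network:

import functools

import torch
import torch.nn as nn


def forward_hook(module, inputs, outputs, name):
    # Runs right after module.forward(); inputs/outputs are the tensors that flowed through it.
    print("forward:", name, type(module).__name__)


def backward_hook(grad_inputs, grad_outputs, op):
    # Runs when this autograd node executes during loss.backward().
    print("backward:", op)


def install_backward_hooks(grad_fn):
    # Walk the autograd graph from an output tensor's grad_fn, hooking every node.
    # AccumulateGrad nodes are skipped, mirroring the tracer in the diff below.
    if grad_fn is None or "AccumulateGrad" in str(grad_fn):
        return
    grad_fn.register_hook(functools.partial(backward_hook, op=str(grad_fn)))
    for next_fn, _ in grad_fn.next_functions:
        install_backward_hooks(next_fn)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
for name, module in model.named_modules():
    module.register_forward_hook(functools.partial(forward_hook, name=name))

out = model(torch.rand(4, 8))
install_backward_hooks(out.grad_fn)
out.sum().backward()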
@@ -1,9 +1,91 @@
 import argparse
+import functools
+import torch
 import options.options as option
 from models.networks import define_G
-import torch
-import torchvision
-import torch.nn.functional as F
+
+
+class TracedModule:
+    def __init__(self, idname):
+        self.idname = idname
+        self.traced_outputs = []
+        self.traced_inputs = []
+
+
+class TorchCustomTrace:
+    def __init__(self):
+        self.module_name_counter = {}
+        self.modules = {}
+        self.graph = {}
+        self.module_map_by_inputs = {}
+        self.module_map_by_outputs = {}
+        self.inputs_to_func_output_tuple = {}
+
+    def add_tracked_module(self, mod: torch.nn.Module):
+        modname = type(mod).__name__
+        if modname not in self.module_name_counter.keys():
+            self.module_name_counter[modname] = 0
+        self.module_name_counter[modname] += 1
+        idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
+        self.modules[idname] = TracedModule(idname)
+        return idname
+
+    # Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
+    # the backwards pass.
+    def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
+        mod = self.modules[mod_id]
+        for li in inputs:
+            if type(li) == torch.Tensor:
+                li = [li]
+            if type(li) == list:
+                for i in li:
+                    if i.data_ptr() in self.module_map_by_inputs.keys():
+                        self.module_map_by_inputs[i.data_ptr()].append(mod)
+                    else:
+                        self.module_map_by_inputs[i.data_ptr()] = [mod]
+        for o in outputs:
+            if o.data_ptr() in self.module_map_by_inputs.keys():
+                self.module_map_by_inputs[o.data_ptr()].append(mod)
+            else:
+                self.module_map_by_inputs[o.data_ptr()] = [mod]
+        # print(trace, [i.data_ptr() for i in inputs], [o.data_ptr() for o in outputs])
+
+    def mem_backward_hook(self, inputs, outputs, op):
+        if len(inputs) == 0:
+            print("No inputs.. %s" % (op,))
+        outs = [o.data_ptr() for o in outputs]
+        tup = (outs, op)
+        #print(tup)
+        for li in inputs:
+            if type(li) == torch.Tensor:
+                li = [li]
+            if type(li) == list:
+                for i in li:
+                    if i.data_ptr() in self.module_map_by_inputs.keys():
+                        print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
+                    if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
+                        self.inputs_to_func_output_tuple[i.data_ptr()].append(tup)
+                    else:
+                        self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
+
+    def install_hooks(self, mod: torch.nn.Module, trace=""):
+        mod_id = self.add_tracked_module(mod)
+        my_trace = trace + "->" + mod_id
+        # If this module has parameters, it also has a state worth tracking.
+        #if next(mod.parameters(recurse=False), None) is not None:
+        mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
+        for m in mod.children():
+            self.install_hooks(m, my_trace)
+
+    def install_backward_hooks(self, grad_fn):
+        # AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
+        # tracing the graph.
+        if grad_fn is None or "AccumulateGrad" in str(grad_fn):
+            return
+        grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
+        for g, _ in grad_fn.next_functions:
+            self.install_backward_hooks(g)
 
 
 if __name__ == "__main__":
@@ -13,9 +95,9 @@ if __name__ == "__main__":
     opt = option.dict_to_nonedict(opt)
 
     netG = define_G(opt)
-    dummyInput = torch.rand(1,3,8,8)
+    dummyInput = torch.rand(1,3,32,32)
 
-    mode = 'trace'
+    mode = 'memtrace'
     if mode == 'torchscript':
         print("Tracing generator network..")
         traced_netG = torch.jit.trace(netG, dummyInput)
@@ -31,7 +113,16 @@ if __name__ == "__main__":
         dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
 
         torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
-                          output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
+                          output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
+    elif mode == 'memtrace':
+        criterion = torch.nn.MSELoss()
+        tracer = TorchCustomTrace()
+        tracer.install_hooks(netG)
+        out, = netG(dummyInput)
+        tracer.install_backward_hooks(out.grad_fn)
+        target = torch.zeros_like(out)
+        loss = criterion(out, target)
+        loss.backward()
     elif mode == 'trace':
         out = netG.forward(dummyInput)[0]
         print(out.shape)
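After the new 'memtrace' branch finishes loss.backward(), everything the tracer collected is keyed by tensor data_ptr() values. One way that state might be inspected afterwards (an illustrative sketch appended after the memtrace branch, not part of this commit):

# Illustrative only: assumes the 'tracer' built in the memtrace branch above.
for ptr, mods in tracer.module_map_by_inputs.items():
    # Modules whose forward pass touched a tensor stored at this address.
    print("%d -> modules: %s" % (ptr, [m.idname for m in mods]))
for ptr, tups in tracer.inputs_to_func_output_tuple.items():
    # Each entry is (output data_ptrs, stringified autograd op) recorded by the backward hook.
    print("%d -> ops: %s" % (ptr, [op for _, op in tups]))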