From 6f2bc36c6108e976b0068a9e7b046d35f34b5797 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 27 Jun 2020 08:28:09 -0600
Subject: [PATCH] Distill_torchscript mods

Starts down the path of writing a custom trace that works using torch's
hook mechanism.
---
 codes/distill_torchscript.py | 103 +++++++++++++++++++++++++++++++++--
 1 file changed, 97 insertions(+), 6 deletions(-)

diff --git a/codes/distill_torchscript.py b/codes/distill_torchscript.py
index b10373cd..dd23c3f5 100644
--- a/codes/distill_torchscript.py
+++ b/codes/distill_torchscript.py
@@ -1,9 +1,91 @@
 import argparse
+import functools
+import torch
 import options.options as option
 from models.networks import define_G
-import torch
-import torchvision
-import torch.nn.functional as F
+
+
+class TracedModule:
+    def __init__(self, idname):
+        self.idname = idname
+        self.traced_outputs = []
+        self.traced_inputs = []
+
+
+class TorchCustomTrace:
+    def __init__(self):
+        self.module_name_counter = {}
+        self.modules = {}
+        self.graph = {}
+        self.module_map_by_inputs = {}
+        self.module_map_by_outputs = {}
+        self.inputs_to_func_output_tuple = {}
+
+    def add_tracked_module(self, mod: torch.nn.Module):
+        modname = type(mod).__name__
+        if modname not in self.module_name_counter.keys():
+            self.module_name_counter[modname] = 0
+        self.module_name_counter[modname] += 1
+        idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
+        self.modules[idname] = TracedModule(idname)
+        return idname
+
+    # Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
+    # the backwards pass.
+    def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
+        mod = self.modules[mod_id]
+        for li in inputs:
+            if type(li) == torch.Tensor:
+                li = [li]
+            if type(li) == list:
+                for i in li:
+                    if i.data_ptr() in self.module_map_by_inputs.keys():
+                        self.module_map_by_inputs[i.data_ptr()].append(mod)
+                    else:
+                        self.module_map_by_inputs[i.data_ptr()] = [mod]
+        for o in outputs:
+            if o.data_ptr() in self.module_map_by_inputs.keys():
+                self.module_map_by_inputs[o.data_ptr()].append(mod)
+            else:
+                self.module_map_by_inputs[o.data_ptr()] = [mod]
+        # print(trace, [i.data_ptr() for i in inputs], [o.data_ptr() for o in outputs])
+
+    def mem_backward_hook(self, inputs, outputs, op):
+        if len(inputs) == 0:
+            print("No inputs.. %s" % (op,))
+        outs = [o.data_ptr() for o in outputs]
+        tup = (outs, op)
+        #print(tup)
+        for li in inputs:
+            if type(li) == torch.Tensor:
+                li = [li]
+            if type(li) == list:
+                for i in li:
+                    if i.data_ptr() in self.module_map_by_inputs.keys():
+                        print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
+                    if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
+                        self.inputs_to_func_output_tuple[i.data_ptr()].append(tup)
+                    else:
+                        self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
+
+    def install_hooks(self, mod: torch.nn.Module, trace=""):
+        mod_id = self.add_tracked_module(mod)
+        my_trace = trace + "->" + mod_id
+        # If this module has parameters, it also has a state worth tracking.
+        #if next(mod.parameters(recurse=False), None) is not None:
+        mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
+
+        for m in mod.children():
+            self.install_hooks(m, my_trace)
+
+    def install_backward_hooks(self, grad_fn):
+        # AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
+        # tracing the graph.
+        if grad_fn is None or "AccumulateGrad" in str(grad_fn):
+            return
+        grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
+        for g, _ in grad_fn.next_functions:
+            self.install_backward_hooks(g)
 
 
 if __name__ == "__main__":
@@ -13,9 +95,9 @@ if __name__ == "__main__":
     opt = option.dict_to_nonedict(opt)
     netG = define_G(opt)
 
-    dummyInput = torch.rand(1,3,8,8)
+    dummyInput = torch.rand(1,3,32,32)
 
-    mode = 'trace'
+    mode = 'memtrace'
     if mode == 'torchscript':
         print("Tracing generator network..")
         traced_netG = torch.jit.trace(netG, dummyInput)
@@ -31,7 +113,16 @@ if __name__ == "__main__":
         dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'},
                         'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
         torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
-                          output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
+                          output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
+    elif mode == 'memtrace':
+        criterion = torch.nn.MSELoss()
+        tracer = TorchCustomTrace()
+        tracer.install_hooks(netG)
+        out, = netG(dummyInput)
+        tracer.install_backward_hooks(out.grad_fn)
+        target = torch.zeros_like(out)
+        loss = criterion(out, target)
+        loss.backward()
     elif mode == 'trace':
         out = netG.forward(dummyInput)[0]
         print(out.shape)
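For context on what the new TorchCustomTrace does, here is a minimal, self-contained sketch of the same two-sided hook mechanism on a toy model: forward hooks on nn.Modules record which tensors (identified by data_ptr()) flow through each module, while hooks installed by walking grad_fn.next_functions cover the functional ops that never appear as modules. The TinyNet model and the forward_hook/backward_hook/walk_graph names below are illustrative only and are not part of this patch or the repository.

# A minimal sketch of the hook mechanism the patch builds on. TinyNet and the
# hook/walk names are hypothetical; only the PyTorch calls mirror the patch.
import functools

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


def forward_hook(module, inputs, output, name):
    # Fires once per nn.Module call; data_ptr() gives a cheap identity for each
    # tensor, which is how the tracer correlates module inputs and outputs.
    in_ptrs = [i.data_ptr() for i in inputs if isinstance(i, torch.Tensor)]
    print("forward  %s: in=%s out=%d" % (name, in_ptrs, output.data_ptr()))


def backward_hook(grad_inputs, grad_outputs, op):
    # Fires when autograd executes this graph node during backward().
    print("backward %s" % op)


def walk_graph(grad_fn):
    # grad_fn nodes form the autograd graph. AccumulateGrad leaves only sink
    # gradients into parameters, so they are skipped, as in the patch.
    if grad_fn is None or "AccumulateGrad" in str(grad_fn):
        return
    grad_fn.register_hook(functools.partial(backward_hook, op=str(grad_fn)))
    for next_fn, _ in grad_fn.next_functions:
        walk_graph(next_fn)


net = TinyNet()
for name, module in net.named_modules():
    module.register_forward_hook(functools.partial(forward_hook, name=name or "root"))

out = net(torch.rand(1, 3, 8, 8))   # forward hooks print here
walk_graph(out.grad_fn)             # install backward hooks on the autograd graph
out.sum().backward()                # backward hooks print here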