From ef5d8a0ed19911be44716029ab50bc0ea63e1896 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 5 Jun 2020 21:01:50 -0600 Subject: [PATCH] Misc --- codes/data_scripts/extract_subimages.py | 7 +-- codes/distill_torchscript.py | 63 +++++++++++++++++++++++-- codes/train.py | 2 +- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py index 2bb0c865..d7d4f8b2 100644 --- a/codes/data_scripts/extract_subimages.py +++ b/codes/data_scripts/extract_subimages.py @@ -20,8 +20,8 @@ def main(): # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. if mode == 'single': - opt['input_folder'] = 'F:\\4k6k\\datasets\\imagesets\\new' - opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\unfiltered_tiled_2x' + opt['input_folder'] = 'F:\\4k6k\\datasets\\hands_on_hc\\images' + opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\tiled_512px' opt['crop_sz'] = 512 # the size of each sub-image opt['step'] = 440 # step of the sliding crop window opt['thres_sz'] = 120 # size threshold @@ -82,9 +82,6 @@ def extract_single(opt, split_img=False): if not osp.exists(save_folder): os.makedirs(save_folder) print('mkdir [{:s}] ...'.format(save_folder)) - else: - print('Folder [{:s}] already exists. Exit...'.format(save_folder)) - sys.exit(1) img_list = data_util._get_paths_from_images(input_folder) def update(arg): diff --git a/codes/distill_torchscript.py b/codes/distill_torchscript.py index fd5f7455..336cc82f 100644 --- a/codes/distill_torchscript.py +++ b/codes/distill_torchscript.py @@ -15,17 +15,70 @@ if __name__ == "__main__": netG = define_G(opt) dummyInput = torch.rand(1,3,8,8) - torchscript = False - if torchscript: + mode = 'torchscript' + if mode == 'torchscript': print("Tracing generator network..") traced_netG = torch.jit.trace(netG, dummyInput) traced_netG.save('../results/ts_generator.zip') - print(traced_netG) - else: + + print(traced_netG.code) + for i, module in enumerate(traced_netG.RRDB_trunk.modules()): + print(i, str(module)) + elif mode == 'onnx': print("Performing onnx trace") input_names = ["lr_input"] output_names = ["hr_image"] dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}} torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names, - output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11) \ No newline at end of file + output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11) + elif mode == 'trace': + out = netG.forward(dummyInput)[0] + print(out.shape) + # Build the graph backwards. + graph = build_graph(out, 'output') + +def get_unique_id_for_fn(fn): + return (str(fn).split(" object at ")[1])[:-1] + +class GraphNode: + def __init__(self, fn): + self.name = (str(fn).split(" object at ")[0])[1:] + self.fn = fn + self.children = {} + self.parents = {} + + def add_parent(self, parent): + self.parents[get_unique_id_for_fn(parent)] = parent + + def add_child(self, child): + self.children[get_unique_id_for_fn(child)] = child + +class TorchGraph: + def __init__(self): + self.tensor_map = {} + + def get_node_for_tensor(self, t): + return self.tensor_map[get_unique_id_for_fn(t)] + + def init(self, output_tensor): + self.build_graph_backwards(output_tensor.grad_fn, None) + # Find inputs + self.inputs = [] + for v in self.tensor_map.values(): + # Is an input if the parents dict is empty. + if bool(v.parents): + self.inputs.append(v) + + def build_graph_backwards(self, fn, previous_fn): + id = get_unique_id_for_fn(fn) + if id in self.tensor_map: + node = self.tensor_map[id] + node.add_child(previous_fn) + else: + node = GraphNode(fn) + self.tensor_map[id] = node + # Propagate to children + for child_fn in fn.next_functions: + node.add_parent(self.build_graph_backwards(child_fn, fn)) + return node \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index d93358e4..99b2d768 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)