Misc
This commit is contained in:
parent
318a604405
commit
ef5d8a0ed1
|
@ -20,8 +20,8 @@ def main():
|
|||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
if mode == 'single':
|
||||
opt['input_folder'] = 'F:\\4k6k\\datasets\\imagesets\\new'
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\unfiltered_tiled_2x'
|
||||
opt['input_folder'] = 'F:\\4k6k\\datasets\\hands_on_hc\\images'
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\tiled_512px'
|
||||
opt['crop_sz'] = 512 # the size of each sub-image
|
||||
opt['step'] = 440 # step of the sliding crop window
|
||||
opt['thres_sz'] = 120 # size threshold
|
||||
|
@ -82,9 +82,6 @@ def extract_single(opt, split_img=False):
|
|||
if not osp.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
print('mkdir [{:s}] ...'.format(save_folder))
|
||||
else:
|
||||
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
|
||||
sys.exit(1)
|
||||
img_list = data_util._get_paths_from_images(input_folder)
|
||||
|
||||
def update(arg):
|
||||
|
|
|
@ -15,13 +15,16 @@ if __name__ == "__main__":
|
|||
netG = define_G(opt)
|
||||
dummyInput = torch.rand(1,3,8,8)
|
||||
|
||||
torchscript = False
|
||||
if torchscript:
|
||||
mode = 'torchscript'
|
||||
if mode == 'torchscript':
|
||||
print("Tracing generator network..")
|
||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
||||
traced_netG.save('../results/ts_generator.zip')
|
||||
print(traced_netG)
|
||||
else:
|
||||
|
||||
print(traced_netG.code)
|
||||
for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
|
||||
print(i, str(module))
|
||||
elif mode == 'onnx':
|
||||
print("Performing onnx trace")
|
||||
input_names = ["lr_input"]
|
||||
output_names = ["hr_image"]
|
||||
|
@ -29,3 +32,53 @@ if __name__ == "__main__":
|
|||
|
||||
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
|
||||
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
|
||||
elif mode == 'trace':
|
||||
out = netG.forward(dummyInput)[0]
|
||||
print(out.shape)
|
||||
# Build the graph backwards.
|
||||
graph = build_graph(out, 'output')
|
||||
|
||||
def get_unique_id_for_fn(fn):
|
||||
return (str(fn).split(" object at ")[1])[:-1]
|
||||
|
||||
class GraphNode:
|
||||
def __init__(self, fn):
|
||||
self.name = (str(fn).split(" object at ")[0])[1:]
|
||||
self.fn = fn
|
||||
self.children = {}
|
||||
self.parents = {}
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents[get_unique_id_for_fn(parent)] = parent
|
||||
|
||||
def add_child(self, child):
|
||||
self.children[get_unique_id_for_fn(child)] = child
|
||||
|
||||
class TorchGraph:
|
||||
def __init__(self):
|
||||
self.tensor_map = {}
|
||||
|
||||
def get_node_for_tensor(self, t):
|
||||
return self.tensor_map[get_unique_id_for_fn(t)]
|
||||
|
||||
def init(self, output_tensor):
|
||||
self.build_graph_backwards(output_tensor.grad_fn, None)
|
||||
# Find inputs
|
||||
self.inputs = []
|
||||
for v in self.tensor_map.values():
|
||||
# Is an input if the parents dict is empty.
|
||||
if bool(v.parents):
|
||||
self.inputs.append(v)
|
||||
|
||||
def build_graph_backwards(self, fn, previous_fn):
|
||||
id = get_unique_id_for_fn(fn)
|
||||
if id in self.tensor_map:
|
||||
node = self.tensor_map[id]
|
||||
node.add_child(previous_fn)
|
||||
else:
|
||||
node = GraphNode(fn)
|
||||
self.tensor_map[id] = node
|
||||
# Propagate to children
|
||||
for child_fn in fn.next_functions:
|
||||
node.add_parent(self.build_graph_backwards(child_fn, fn))
|
||||
return node
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user