This commit is contained in:
James Betker 2020-06-05 21:01:50 -06:00
parent 318a604405
commit ef5d8a0ed1
3 changed files with 61 additions and 11 deletions

View File

@ -20,8 +20,8 @@ def main():
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
if mode == 'single':
opt['input_folder'] = 'F:\\4k6k\\datasets\\imagesets\\new'
opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\unfiltered_tiled_2x'
opt['input_folder'] = 'F:\\4k6k\\datasets\\hands_on_hc\\images'
opt['save_folder'] = 'F:\\4k6k\\datasets\\imagesets\\tiled_512px'
opt['crop_sz'] = 512 # the size of each sub-image
opt['step'] = 440 # step of the sliding crop window
opt['thres_sz'] = 120 # size threshold
@ -82,9 +82,6 @@ def extract_single(opt, split_img=False):
if not osp.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
else:
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
sys.exit(1)
img_list = data_util._get_paths_from_images(input_folder)
def update(arg):

View File

@ -15,17 +15,70 @@ if __name__ == "__main__":
netG = define_G(opt)
dummyInput = torch.rand(1,3,8,8)
torchscript = False
if torchscript:
mode = 'torchscript'
if mode == 'torchscript':
print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput)
traced_netG.save('../results/ts_generator.zip')
print(traced_netG)
else:
print(traced_netG.code)
for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
print(i, str(module))
elif mode == 'onnx':
print("Performing onnx trace")
input_names = ["lr_input"]
output_names = ["hr_image"]
dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
elif mode == 'trace':
out = netG.forward(dummyInput)[0]
print(out.shape)
# Build the graph backwards.
graph = build_graph(out, 'output')
def get_unique_id_for_fn(fn):
return (str(fn).split(" object at ")[1])[:-1]
class GraphNode:
def __init__(self, fn):
self.name = (str(fn).split(" object at ")[0])[1:]
self.fn = fn
self.children = {}
self.parents = {}
def add_parent(self, parent):
self.parents[get_unique_id_for_fn(parent)] = parent
def add_child(self, child):
self.children[get_unique_id_for_fn(child)] = child
class TorchGraph:
def __init__(self):
self.tensor_map = {}
def get_node_for_tensor(self, t):
return self.tensor_map[get_unique_id_for_fn(t)]
def init(self, output_tensor):
self.build_graph_backwards(output_tensor.grad_fn, None)
# Find inputs
self.inputs = []
for v in self.tensor_map.values():
# Is an input if the parents dict is empty.
if bool(v.parents):
self.inputs.append(v)
def build_graph_backwards(self, fn, previous_fn):
id = get_unique_id_for_fn(fn)
if id in self.tensor_map:
node = self.tensor_map[id]
node.add_child(previous_fn)
else:
node = GraphNode(fn)
self.tensor_map[id] = node
# Propagate to children
for child_fn in fn.next_functions:
node.add_parent(self.build_graph_backwards(child_fn, fn))
return node

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_rrdb_xl_wideres.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)