From 9ead2c0a08eeb1fbebde53bf3662bde3a4e9c83c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 17 Oct 2020 22:54:12 -0600 Subject: [PATCH] Multiscale training in! --- codes/data/__init__.py | 2 + codes/data/multiscale_dataset.py | 23 +++-- codes/models/archs/ChainedEmbeddingGen.py | 15 ++- codes/models/networks.py | 7 +- codes/models/steps/injectors.py | 4 +- codes/models/steps/losses.py | 2 +- codes/models/steps/progressive_zoom.py | 120 ++++++++++++++++++++++ codes/train.py | 2 +- codes/train2.py | 2 +- 9 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 codes/models/steps/progressive_zoom.py diff --git a/codes/data/__init__.py b/codes/data/__init__.py index b08d7d9f..5996436c 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -37,6 +37,8 @@ def create_dataset(dataset_opt): from data.multi_frame_dataset import MultiFrameDataset as D elif mode == 'combined': from data.combined_dataset import CombinedDataset as D + elif mode == 'multiscale': + from data.multiscale_dataset import MultiScaleDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index 4a51528c..0a3dc0b4 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -19,7 +19,7 @@ class MultiScaleDataset(data.Dataset): self.num_scales = self.opt['num_scales'] self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.scale = self.opt['scale'] - self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['dataroot'], [1]) + self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1]) # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # offset from center is chosen on a normal probability curve. @@ -43,7 +43,7 @@ class MultiScaleDataset(data.Dataset): if depth >= self.num_scales: return patch_size = self.hq_size_cap // (2 ** depth) - # First pull the four sub-patches. + # First pull the four sub-patches. Important: if this is changed, be sure to edit build_multiscale_patch_index_map() below. patches = [input_img[:patch_size, :patch_size], input_img[:patch_size, patch_size:], input_img[patch_size:, :patch_size], @@ -67,20 +67,29 @@ class MultiScaleDataset(data.Dataset): if patches_hq[0].shape[2] == 3: patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq] patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] + patches_hq = torch.stack(patches_hq, dim=0) patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq] + patches_lq = torch.stack(patches_lq, dim=0) - d = {'LQ': patches_lq, 'HQ': patches_hq, 'GT_path': full_path} + d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path} return d def __len__(self): return len(self.paths_hq) class MultiscaleTreeNode: - def __init__(self, index, parent): + def __init__(self, index, parent, i): self.index = index self.parent = parent self.children = [] + # These represent the offset from left and top of the image for the individual patch as a proportion of the entire image. + # Tightly tied to the implementation above for the order in which the patches are pulled from the base image. + lefts = [0, .5, 0, .5] + tops = [0, 0, .5, .5] + self.left = lefts[i] + self.top = tops[i] + def add_child(self, child): self.children.append(child) return child @@ -89,14 +98,14 @@ class MultiscaleTreeNode: def build_multiscale_patch_index_map(depth): if depth < 0: return - root = MultiscaleTreeNode(0, None) + root = MultiscaleTreeNode(0, None, 0) leaves = [] _build_multiscale_patch_index_map(depth-1, 1, root, leaves) return leaves def _build_multiscale_patch_index_map(depth, ind, node, leaves): - subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node)) for i in range(4)] + subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)] ind += 4 if depth == 1: leaves.extend(subnodes) @@ -109,7 +118,7 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves): if __name__ == '__main__': opt = { 'name': 'amalgam', - 'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'], + 'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'], 'num_scales': 4, 'scale': 2, 'hq_tile_size': 128 diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 63dfbce4..d82f00d7 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -66,12 +66,12 @@ class ChainedEmbeddingGen(nn.Module): class ChainedEmbeddingGenWithStructure(nn.Module): - def __init__(self, depth=10, recurrent=False): + def __init__(self, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2): super(ChainedEmbeddingGenWithStructure, self).__init__() self.recurrent = recurrent self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) if recurrent: - self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False) + self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False) self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) @@ -80,12 +80,16 @@ class ChainedEmbeddingGenWithStructure(nn.Module): self.structure_upsample = FinalUpsampleBlock2x(64) self.grad_extract = ImageGradientNoPadding() self.upsample = FinalUpsampleBlock2x(64) + self.ref_join_std = 0 def forward(self, x, recurrent=None): fea = self.initial_conv(x) if self.recurrent: + if recurrent is None: + recurrent = torch.zeros_like(fea) rec = self.recurrent_process(recurrent) - fea, _ = self.recurrent_join(fea, rec) + fea, recstd = self.recurrent_join(fea, rec) + self.ref_join_std = recstd.item() emb = checkpoint(self.spine, fea) grad = fea for i, block in enumerate(self.blocks): @@ -94,4 +98,7 @@ class ChainedEmbeddingGenWithStructure(nn.Module): structure_br = checkpoint(self.structure_joins[i], grad, fea) grad = grad + checkpoint(self.structure_blocks[i], structure_br) out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out) + return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea + + def get_debug_values(self, step, net_name): + return { 'ref_join_std': self.ref_join_std } \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index e2d4aaf6..c12f20c5 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -126,9 +126,10 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == 'chained_gen': netG = ChainedEmbeddingGen(depth=opt_net['depth']) elif which_model == 'chained_gen_structured': - netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False) - elif which_model == 'chained_gen_structuredr2': - netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth']) + rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False + recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3 + recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2 + netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd) elif which_model == "flownet2": from models.flownet2.models import FlowNet2 ld = torch.load(opt_net['load_path']) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 6f4642ab..dc7abff7 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -1,7 +1,6 @@ import torch.nn from models.archs.SPSR_arch import ImageGradientNoPadding from utils.weight_scheduler import get_scheduler_for_opt -#from models.steps.recursive_gen_injectors import ImageFlowInjector from models.steps.losses import extract_params_from_state # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. @@ -10,6 +9,9 @@ def create_injector(opt_inject, env): if 'teco_' in type: from models.steps.tecogan_losses import create_teco_injector return create_teco_injector(opt_inject, env) + elif 'progressive_' in type: + from models.steps.progressive_zoom import create_progressive_zoom_injector + return create_progressive_zoom_injector(opt_inject, env) elif type == 'generator': return ImageGeneratorInjector(opt_inject, env) elif type == 'discriminator': diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 0e37e2c7..72a3dfee 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -37,7 +37,7 @@ def create_loss(opt_loss, env): # Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars. -def extract_params_from_state(params, state, root=True): +def extract_params_from_state(params: object, state: object, root: object = True) -> object: if isinstance(params, list) or isinstance(params, tuple): p = [extract_params_from_state(r, state, False) for r in params] elif isinstance(params, str): diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py new file mode 100644 index 00000000..1c3b9d43 --- /dev/null +++ b/codes/models/steps/progressive_zoom.py @@ -0,0 +1,120 @@ +import os +import random + +import torch +import torchvision + +from data.multiscale_dataset import build_multiscale_patch_index_map +from models.steps.injectors import Injector +from models.steps.losses import extract_params_from_state +from models.steps.tecogan_losses import extract_inputs_index +import os.path as osp + + +def create_progressive_zoom_injector(opt, env): + type = opt['type'] + if type == 'progressive_zoom_generator': + return ProgressiveGeneratorInjector(opt, env) + return None + + +class ProgressiveGeneratorInjector(Injector): + def __init__(self, opt, env): + super(ProgressiveGeneratorInjector, self).__init__(opt, env) + self.gen_key = opt['generator'] + self.hq_key = opt['hq'] # The key where HQ images are stored. + self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored. + self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 + self.output_hq_index = opt['output_hq_index'] + self.recurrent_output_index = opt['recurrent_output_index'] + self.recurrent_index = opt['recurrent_index'] + self.depth = opt['depth'] + self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape. + self.multiscale_leaves = build_multiscale_patch_index_map(self.depth) + self.feed_gen_output_into_input = opt['feed_gen_output_into_input'] + + # Given a set of multiscale inputs, selects self.num_branches leaves and produces that many chains of inputs, + # excluding the base input for efficiency reasons. + # Output is a list of chains. Each chain is itself a list of links. Each link is MultiscaleTreeNode + def get_input_chains(self): + leaves = random.sample(self.multiscale_leaves, self.number_branches) + chains = [] + for leaf in leaves: + chain = [leaf] + node = leaf.parent + while node.parent is not None: + chain.insert(0, node) + node = node.parent + chains.append(chain) + return chains + + def feed_forward(self, gen, inputs, results, lq_input, recurrent_input): + ff_input = inputs.copy() + ff_input[self.input_lq_index] = lq_input + ff_input[self.recurrent_index] = recurrent_input + gen_out = gen(*ff_input) + if isinstance(gen_out, torch.Tensor): + gen_out = [gen_out] + for i, out_key in enumerate(self.output): + results[out_key].append(gen_out[i]) + return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index] + + def forward(self, state): + gen = self.env['generators'][self.gen_key] + inputs = extract_params_from_state(self.input, state) + lq_inputs = inputs[self.input_lq_index] + hq_inputs = state[self.hq_key] + if not isinstance(inputs, list): + inputs = [inputs] + if not isinstance(self.output, list): + self.output = [self.output] + results = {} # A list of outputs produced by feeding each progressive lq input into the generator. + results_hq = [] + for out_key in self.output: + results[out_key] = [] + + b, f, h, w = lq_inputs[:, 0].shape + base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None) + results_hq.append(hq_inputs[:, 0]) + input_chains = self.get_input_chains() + debug_index = 0 + for chain in input_chains: + chain_input = [lq_inputs[:, 0]] + chain_output = [base_hq_out] + recurrent_hq = base_hq_out + recurrent = base_recurrent + for link in chain: # Remember, `link` is a MultiscaleTreeNode. + if self.feed_gen_output_into_input: + top = int(link.top * 2 * h) + left = int(link.left * 2 * w) + lq_input = recurrent_hq[:, :, top:top+h, left:left+w] + else: + lq_input = lq_inputs[:, link.index] + chain_input.append(lq_input) + recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent) + chain_output.append(recurrent_hq) + results_hq.append(hq_inputs[:, link.index]) + + if self.env['step'] % 1 == 0: + self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index) + debug_index += 1 + results[self.hq_output_key] = results_hq + for k, v in results.items(): + results[k] = torch.stack(v, dim=1) + return results + + + def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it): + if self.env['rank'] > 0: + return + if self.feed_gen_output_into_input: + lbl = 'generator_recurrent' + else: + lbl = 'generator_regular' + base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step'])) + os.makedirs(base_path, exist_ok=True) + ind = 1 + for i, o in zip(chain_inputs, chain_outputs): + torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind))) + torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind))) + ind += 1 diff --git a/codes/train.py b/codes/train.py index de9ce78d..0c7de24b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structuredr2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_chained.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/train2.py b/codes/train2.py index f38eee3e..f52ec9d3 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_constrained.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_chained_structured.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()