Multiscale training in!

This commit is contained in:
James Betker 2020-10-17 22:54:12 -06:00
parent e706911c83
commit 9ead2c0a08
9 changed files with 159 additions and 18 deletions

View File

@ -37,6 +37,8 @@ def create_dataset(dataset_opt):
from data.multi_frame_dataset import MultiFrameDataset as D from data.multi_frame_dataset import MultiFrameDataset as D
elif mode == 'combined': elif mode == 'combined':
from data.combined_dataset import CombinedDataset as D from data.combined_dataset import CombinedDataset as D
elif mode == 'multiscale':
from data.multiscale_dataset import MultiScaleDataset as D
else: else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt) dataset = D(dataset_opt)

View File

@ -19,7 +19,7 @@ class MultiScaleDataset(data.Dataset):
self.num_scales = self.opt['num_scales'] self.num_scales = self.opt['num_scales']
self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.hq_size_cap = self.tile_size * 2 ** self.num_scales
self.scale = self.opt['scale'] self.scale = self.opt['scale']
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['dataroot'], [1]) self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve. # offset from center is chosen on a normal probability curve.
@ -43,7 +43,7 @@ class MultiScaleDataset(data.Dataset):
if depth >= self.num_scales: if depth >= self.num_scales:
return return
patch_size = self.hq_size_cap // (2 ** depth) patch_size = self.hq_size_cap // (2 ** depth)
# First pull the four sub-patches. # First pull the four sub-patches. Important: if this is changed, be sure to edit build_multiscale_patch_index_map() below.
patches = [input_img[:patch_size, :patch_size], patches = [input_img[:patch_size, :patch_size],
input_img[:patch_size, patch_size:], input_img[:patch_size, patch_size:],
input_img[patch_size:, :patch_size], input_img[patch_size:, :patch_size],
@ -67,20 +67,29 @@ class MultiScaleDataset(data.Dataset):
if patches_hq[0].shape[2] == 3: if patches_hq[0].shape[2] == 3:
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq] patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
patches_hq = torch.stack(patches_hq, dim=0)
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq] patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq]
patches_lq = torch.stack(patches_lq, dim=0)
d = {'LQ': patches_lq, 'HQ': patches_hq, 'GT_path': full_path} d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
return d return d
def __len__(self): def __len__(self):
return len(self.paths_hq) return len(self.paths_hq)
class MultiscaleTreeNode: class MultiscaleTreeNode:
def __init__(self, index, parent): def __init__(self, index, parent, i):
self.index = index self.index = index
self.parent = parent self.parent = parent
self.children = [] self.children = []
# These represent the offset from left and top of the image for the individual patch as a proportion of the entire image.
# Tightly tied to the implementation above for the order in which the patches are pulled from the base image.
lefts = [0, .5, 0, .5]
tops = [0, 0, .5, .5]
self.left = lefts[i]
self.top = tops[i]
def add_child(self, child): def add_child(self, child):
self.children.append(child) self.children.append(child)
return child return child
@ -89,14 +98,14 @@ class MultiscaleTreeNode:
def build_multiscale_patch_index_map(depth): def build_multiscale_patch_index_map(depth):
if depth < 0: if depth < 0:
return return
root = MultiscaleTreeNode(0, None) root = MultiscaleTreeNode(0, None, 0)
leaves = [] leaves = []
_build_multiscale_patch_index_map(depth-1, 1, root, leaves) _build_multiscale_patch_index_map(depth-1, 1, root, leaves)
return leaves return leaves
def _build_multiscale_patch_index_map(depth, ind, node, leaves): def _build_multiscale_patch_index_map(depth, ind, node, leaves):
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node)) for i in range(4)] subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)]
ind += 4 ind += 4
if depth == 1: if depth == 1:
leaves.extend(subnodes) leaves.extend(subnodes)
@ -109,7 +118,7 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves):
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'], 'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'],
'num_scales': 4, 'num_scales': 4,
'scale': 2, 'scale': 2,
'hq_tile_size': 128 'hq_tile_size': 128

View File

@ -66,12 +66,12 @@ class ChainedEmbeddingGen(nn.Module):
class ChainedEmbeddingGenWithStructure(nn.Module): class ChainedEmbeddingGenWithStructure(nn.Module):
def __init__(self, depth=10, recurrent=False): def __init__(self, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
super(ChainedEmbeddingGenWithStructure, self).__init__() super(ChainedEmbeddingGenWithStructure, self).__init__()
self.recurrent = recurrent self.recurrent = recurrent
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
if recurrent: if recurrent:
self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False) self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
@ -80,12 +80,16 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
self.structure_upsample = FinalUpsampleBlock2x(64) self.structure_upsample = FinalUpsampleBlock2x(64)
self.grad_extract = ImageGradientNoPadding() self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64) self.upsample = FinalUpsampleBlock2x(64)
self.ref_join_std = 0
def forward(self, x, recurrent=None): def forward(self, x, recurrent=None):
fea = self.initial_conv(x) fea = self.initial_conv(x)
if self.recurrent: if self.recurrent:
if recurrent is None:
recurrent = torch.zeros_like(fea)
rec = self.recurrent_process(recurrent) rec = self.recurrent_process(recurrent)
fea, _ = self.recurrent_join(fea, rec) fea, recstd = self.recurrent_join(fea, rec)
self.ref_join_std = recstd.item()
emb = checkpoint(self.spine, fea) emb = checkpoint(self.spine, fea)
grad = fea grad = fea
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
@ -94,4 +98,7 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
structure_br = checkpoint(self.structure_joins[i], grad, fea) structure_br = checkpoint(self.structure_joins[i], grad, fea)
grad = grad + checkpoint(self.structure_blocks[i], structure_br) grad = grad + checkpoint(self.structure_blocks[i], structure_br)
out = checkpoint(self.upsample, fea) out = checkpoint(self.upsample, fea)
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out) return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
def get_debug_values(self, step, net_name):
return { 'ref_join_std': self.ref_join_std }

View File

@ -126,9 +126,10 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == 'chained_gen': elif which_model == 'chained_gen':
netG = ChainedEmbeddingGen(depth=opt_net['depth']) netG = ChainedEmbeddingGen(depth=opt_net['depth'])
elif which_model == 'chained_gen_structured': elif which_model == 'chained_gen_structured':
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False) rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
elif which_model == 'chained_gen_structuredr2': recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth']) recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd)
elif which_model == "flownet2": elif which_model == "flownet2":
from models.flownet2.models import FlowNet2 from models.flownet2.models import FlowNet2
ld = torch.load(opt_net['load_path']) ld = torch.load(opt_net['load_path'])

View File

@ -1,7 +1,6 @@
import torch.nn import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding from models.archs.SPSR_arch import ImageGradientNoPadding
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
#from models.steps.recursive_gen_injectors import ImageFlowInjector
from models.steps.losses import extract_params_from_state from models.steps.losses import extract_params_from_state
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
@ -10,6 +9,9 @@ def create_injector(opt_inject, env):
if 'teco_' in type: if 'teco_' in type:
from models.steps.tecogan_losses import create_teco_injector from models.steps.tecogan_losses import create_teco_injector
return create_teco_injector(opt_inject, env) return create_teco_injector(opt_inject, env)
elif 'progressive_' in type:
from models.steps.progressive_zoom import create_progressive_zoom_injector
return create_progressive_zoom_injector(opt_inject, env)
elif type == 'generator': elif type == 'generator':
return ImageGeneratorInjector(opt_inject, env) return ImageGeneratorInjector(opt_inject, env)
elif type == 'discriminator': elif type == 'discriminator':

View File

@ -37,7 +37,7 @@ def create_loss(opt_loss, env):
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars. # Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
def extract_params_from_state(params, state, root=True): def extract_params_from_state(params: object, state: object, root: object = True) -> object:
if isinstance(params, list) or isinstance(params, tuple): if isinstance(params, list) or isinstance(params, tuple):
p = [extract_params_from_state(r, state, False) for r in params] p = [extract_params_from_state(r, state, False) for r in params]
elif isinstance(params, str): elif isinstance(params, str):

View File

@ -0,0 +1,120 @@
import os
import random
import torch
import torchvision
from data.multiscale_dataset import build_multiscale_patch_index_map
from models.steps.injectors import Injector
from models.steps.losses import extract_params_from_state
from models.steps.tecogan_losses import extract_inputs_index
import os.path as osp
def create_progressive_zoom_injector(opt, env):
type = opt['type']
if type == 'progressive_zoom_generator':
return ProgressiveGeneratorInjector(opt, env)
return None
class ProgressiveGeneratorInjector(Injector):
def __init__(self, opt, env):
super(ProgressiveGeneratorInjector, self).__init__(opt, env)
self.gen_key = opt['generator']
self.hq_key = opt['hq'] # The key where HQ images are stored.
self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored.
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index']
self.recurrent_output_index = opt['recurrent_output_index']
self.recurrent_index = opt['recurrent_index']
self.depth = opt['depth']
self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape.
self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
self.feed_gen_output_into_input = opt['feed_gen_output_into_input']
# Given a set of multiscale inputs, selects self.num_branches leaves and produces that many chains of inputs,
# excluding the base input for efficiency reasons.
# Output is a list of chains. Each chain is itself a list of links. Each link is MultiscaleTreeNode
def get_input_chains(self):
leaves = random.sample(self.multiscale_leaves, self.number_branches)
chains = []
for leaf in leaves:
chain = [leaf]
node = leaf.parent
while node.parent is not None:
chain.insert(0, node)
node = node.parent
chains.append(chain)
return chains
def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
ff_input = inputs.copy()
ff_input[self.input_lq_index] = lq_input
ff_input[self.recurrent_index] = recurrent_input
gen_out = gen(*ff_input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]
def forward(self, state):
gen = self.env['generators'][self.gen_key]
inputs = extract_params_from_state(self.input, state)
lq_inputs = inputs[self.input_lq_index]
hq_inputs = state[self.hq_key]
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(self.output, list):
self.output = [self.output]
results = {} # A list of outputs produced by feeding each progressive lq input into the generator.
results_hq = []
for out_key in self.output:
results[out_key] = []
b, f, h, w = lq_inputs[:, 0].shape
base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
results_hq.append(hq_inputs[:, 0])
input_chains = self.get_input_chains()
debug_index = 0
for chain in input_chains:
chain_input = [lq_inputs[:, 0]]
chain_output = [base_hq_out]
recurrent_hq = base_hq_out
recurrent = base_recurrent
for link in chain: # Remember, `link` is a MultiscaleTreeNode.
if self.feed_gen_output_into_input:
top = int(link.top * 2 * h)
left = int(link.left * 2 * w)
lq_input = recurrent_hq[:, :, top:top+h, left:left+w]
else:
lq_input = lq_inputs[:, link.index]
chain_input.append(lq_input)
recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
chain_output.append(recurrent_hq)
results_hq.append(hq_inputs[:, link.index])
if self.env['step'] % 1 == 0:
self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
debug_index += 1
results[self.hq_output_key] = results_hq
for k, v in results.items():
results[k] = torch.stack(v, dim=1)
return results
def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
if self.env['rank'] > 0:
return
if self.feed_gen_output_into_input:
lbl = 'generator_recurrent'
else:
lbl = 'generator_regular'
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
ind = 1
for i, o in zip(chain_inputs, chain_outputs):
torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
ind += 1

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structuredr2.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_chained.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_constrained.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_chained_structured.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()