Multiscale training in!
This commit is contained in:
parent
e706911c83
commit
9ead2c0a08
|
@ -37,6 +37,8 @@ def create_dataset(dataset_opt):
|
|||
from data.multi_frame_dataset import MultiFrameDataset as D
|
||||
elif mode == 'combined':
|
||||
from data.combined_dataset import CombinedDataset as D
|
||||
elif mode == 'multiscale':
|
||||
from data.multiscale_dataset import MultiScaleDataset as D
|
||||
else:
|
||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||
dataset = D(dataset_opt)
|
||||
|
|
|
@ -19,7 +19,7 @@ class MultiScaleDataset(data.Dataset):
|
|||
self.num_scales = self.opt['num_scales']
|
||||
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
||||
self.scale = self.opt['scale']
|
||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['dataroot'], [1])
|
||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
|
||||
|
||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||
# offset from center is chosen on a normal probability curve.
|
||||
|
@ -43,7 +43,7 @@ class MultiScaleDataset(data.Dataset):
|
|||
if depth >= self.num_scales:
|
||||
return
|
||||
patch_size = self.hq_size_cap // (2 ** depth)
|
||||
# First pull the four sub-patches.
|
||||
# First pull the four sub-patches. Important: if this is changed, be sure to edit build_multiscale_patch_index_map() below.
|
||||
patches = [input_img[:patch_size, :patch_size],
|
||||
input_img[:patch_size, patch_size:],
|
||||
input_img[patch_size:, :patch_size],
|
||||
|
@ -67,20 +67,29 @@ class MultiScaleDataset(data.Dataset):
|
|||
if patches_hq[0].shape[2] == 3:
|
||||
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
|
||||
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
||||
patches_hq = torch.stack(patches_hq, dim=0)
|
||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq]
|
||||
patches_lq = torch.stack(patches_lq, dim=0)
|
||||
|
||||
d = {'LQ': patches_lq, 'HQ': patches_hq, 'GT_path': full_path}
|
||||
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
|
||||
return d
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths_hq)
|
||||
|
||||
class MultiscaleTreeNode:
    """One node of the quad-tree that indexes multiscale patches.

    Each node records its flat patch ``index``, its ``parent`` node (``None``
    for the root), its ``children`` and the offset of its patch inside the
    parent's image, expressed as a proportion of that image (``left``/``top``).
    """

    # Offsets from the left and top of the parent image, as a proportion of the
    # entire image, indexed by quadrant number. Tightly tied to the order in
    # which MultiScaleDataset pulls its four sub-patches from the base image
    # (top-left, top-right, bottom-left, bottom-right); change both together.
    LEFTS = (0, .5, 0, .5)
    TOPS = (0, 0, .5, .5)

    def __init__(self, index, parent, i):
        """Create a node.

        Args:
            index: Flat index of the patch this node refers to.
            parent: Parent ``MultiscaleTreeNode``, or ``None`` for the root.
            i: Quadrant number (0-3) of this patch within its parent; selects
                the ``left``/``top`` offsets. Raises ``IndexError`` if it is
                outside the offset tables.
        """
        self.index = index
        self.parent = parent
        self.children = []
        self.left = self.LEFTS[i]
        self.top = self.TOPS[i]

    def add_child(self, child):
        """Append ``child`` to this node's children and return ``child``."""
        self.children.append(child)
        return child
|
||||
|
@ -89,14 +98,14 @@ class MultiscaleTreeNode:
|
|||
def build_multiscale_patch_index_map(depth):
    """Build the multiscale patch tree and return its collected leaf nodes.

    A root ``MultiscaleTreeNode`` (index 0, no parent, quadrant 0) is created
    and handed to ``_build_multiscale_patch_index_map`` together with
    ``depth - 1``, the next free patch index (1) and an empty list that the
    helper fills with leaf nodes.

    Args:
        depth: Depth of the tree to build.

    Returns:
        The list of leaf nodes populated by the recursive helper, or ``None``
        when ``depth`` is negative.
    """
    if depth < 0:
        return None
    # The root stands for the whole base image, so it takes quadrant 0
    # (zero left/top offset).
    root = MultiscaleTreeNode(0, None, 0)
    leaves = []
    _build_multiscale_patch_index_map(depth - 1, 1, root, leaves)
    return leaves
|
||||
|
||||
|
||||
def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
||||
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node)) for i in range(4)]
|
||||
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)]
|
||||
ind += 4
|
||||
if depth == 1:
|
||||
leaves.extend(subnodes)
|
||||
|
@ -109,7 +118,7 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
|
||||
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'],
|
||||
'num_scales': 4,
|
||||
'scale': 2,
|
||||
'hq_tile_size': 128
|
||||
|
|
|
@ -66,12 +66,12 @@ class ChainedEmbeddingGen(nn.Module):
|
|||
|
||||
|
||||
class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||
def __init__(self, depth=10, recurrent=False):
|
||||
def __init__(self, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
|
||||
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
||||
self.recurrent = recurrent
|
||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||
if recurrent:
|
||||
self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
|
||||
self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
|
||||
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||
|
@ -80,12 +80,16 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
|||
self.structure_upsample = FinalUpsampleBlock2x(64)
|
||||
self.grad_extract = ImageGradientNoPadding()
|
||||
self.upsample = FinalUpsampleBlock2x(64)
|
||||
self.ref_join_std = 0
|
||||
|
||||
def forward(self, x, recurrent=None):
|
||||
fea = self.initial_conv(x)
|
||||
if self.recurrent:
|
||||
if recurrent is None:
|
||||
recurrent = torch.zeros_like(fea)
|
||||
rec = self.recurrent_process(recurrent)
|
||||
fea, _ = self.recurrent_join(fea, rec)
|
||||
fea, recstd = self.recurrent_join(fea, rec)
|
||||
self.ref_join_std = recstd.item()
|
||||
emb = checkpoint(self.spine, fea)
|
||||
grad = fea
|
||||
for i, block in enumerate(self.blocks):
|
||||
|
@ -94,4 +98,7 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
|||
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
||||
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||
out = checkpoint(self.upsample, fea)
|
||||
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
||||
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
|
||||
|
||||
def get_debug_values(self, step, net_name):
    """Return debug scalars for logging.

    ``step`` and ``net_name`` are accepted but not used. The only value
    reported is ``ref_join_std``, the attribute stored on this instance.
    """
    debug_values = dict(ref_join_std=self.ref_join_std)
    return debug_values
|
|
@ -126,9 +126,10 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
elif which_model == 'chained_gen':
|
||||
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
|
||||
elif which_model == 'chained_gen_structured':
|
||||
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False)
|
||||
elif which_model == 'chained_gen_structuredr2':
|
||||
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
|
||||
rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
|
||||
recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
|
||||
recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
|
||||
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd)
|
||||
elif which_model == "flownet2":
|
||||
from models.flownet2.models import FlowNet2
|
||||
ld = torch.load(opt_net['load_path'])
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import torch.nn
|
||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
#from models.steps.recursive_gen_injectors import ImageFlowInjector
|
||||
from models.steps.losses import extract_params_from_state
|
||||
|
||||
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
|
||||
|
@ -10,6 +9,9 @@ def create_injector(opt_inject, env):
|
|||
if 'teco_' in type:
|
||||
from models.steps.tecogan_losses import create_teco_injector
|
||||
return create_teco_injector(opt_inject, env)
|
||||
elif 'progressive_' in type:
|
||||
from models.steps.progressive_zoom import create_progressive_zoom_injector
|
||||
return create_progressive_zoom_injector(opt_inject, env)
|
||||
elif type == 'generator':
|
||||
return ImageGeneratorInjector(opt_inject, env)
|
||||
elif type == 'discriminator':
|
||||
|
|
|
@ -37,7 +37,7 @@ def create_loss(opt_loss, env):
|
|||
|
||||
|
||||
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
||||
def extract_params_from_state(params, state, root=True):
|
||||
def extract_params_from_state(params: object, state: object, root: object = True) -> object:
|
||||
if isinstance(params, list) or isinstance(params, tuple):
|
||||
p = [extract_params_from_state(r, state, False) for r in params]
|
||||
elif isinstance(params, str):
|
||||
|
|
120
codes/models/steps/progressive_zoom.py
Normal file
120
codes/models/steps/progressive_zoom.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from data.multiscale_dataset import build_multiscale_patch_index_map
|
||||
from models.steps.injectors import Injector
|
||||
from models.steps.losses import extract_params_from_state
|
||||
from models.steps.tecogan_losses import extract_inputs_index
|
||||
import os.path as osp
|
||||
|
||||
|
||||
def create_progressive_zoom_injector(opt, env):
    """Instantiate the progressive-zoom injector named by ``opt['type']``.

    Args:
        opt: Injector options; must contain the key ``'type'``.
        env: Environment object forwarded unchanged to the injector.

    Returns:
        A ``ProgressiveGeneratorInjector`` when the type is
        ``'progressive_zoom_generator'``; ``None`` for any other type.
    """
    injector_type = opt['type']  # Renamed from `type` to avoid shadowing the builtin.
    if injector_type == 'progressive_zoom_generator':
        return ProgressiveGeneratorInjector(opt, env)
    # Unrecognized types deliberately fall through to None.
    return None
|
||||
|
||||
|
||||
class ProgressiveGeneratorInjector(Injector):
    """Injector that feeds a generator along randomly chosen zoom chains.

    The generator is first run on the base LQ input (patch 0). Then, for each
    of ``number_branches`` randomly sampled leaves of the multiscale patch
    tree, it is run once per node on the path from just below the root down to
    that leaf, passing the recurrent output of the previous run back in.

    ``self.input``, ``self.output`` and ``self.env`` are not set in this class;
    presumably they are provided by the ``Injector`` base class - confirm there.
    """

    # Visual debug images are written when `step % VISUAL_DEBUG_INTERVAL == 0`.
    # At 1 that is every step; raise it to write less often.
    VISUAL_DEBUG_INTERVAL = 1

    def __init__(self, opt, env):
        super(ProgressiveGeneratorInjector, self).__init__(opt, env)
        self.gen_key = opt['generator']
        self.hq_key = opt['hq']  # The key where HQ images are stored.
        self.hq_output_key = opt['hq_output']  # The key where HQ images corresponding with generated images are stored.
        self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
        self.output_hq_index = opt['output_hq_index']
        self.recurrent_output_index = opt['recurrent_output_index']
        self.recurrent_index = opt['recurrent_index']
        self.depth = opt['depth']
        self.number_branches = opt['num_branches']  # Number of input branches to randomly choose for generation. This defines the output shape.
        self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
        self.feed_gen_output_into_input = opt['feed_gen_output_into_input']

    def get_input_chains(self):
        """Sample ``self.number_branches`` leaves and return their chains.

        Returns:
            A list of chains. Each chain is a list of ``MultiscaleTreeNode``
            links ordered from the node directly below the root down to the
            sampled leaf. The root (base input) is excluded for efficiency
            reasons.
        """
        leaves = random.sample(self.multiscale_leaves, self.number_branches)
        chains = []
        for leaf in leaves:
            # Walk leaf -> root, collecting nodes bottom-up, then flip once
            # instead of inserting at the front on every step.
            chain = [leaf]
            node = leaf.parent
            while node.parent is not None:
                chain.append(node)
                node = node.parent
            chain.reverse()
            chains.append(chain)
        return chains

    def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
        """Run ``gen`` once and record its outputs.

        A copy of ``inputs`` is made with the LQ slot and the recurrent slot
        replaced by ``lq_input`` and ``recurrent_input``; the generator is
        called with those positional arguments. Output ``i`` of the generator
        is appended to ``results[self.output[i]]``.

        Returns:
            ``(hq_output, recurrent_output)`` picked from the generator outputs
            at ``self.output_hq_index`` and ``self.recurrent_output_index``.
        """
        ff_input = list(inputs)  # Copy so the caller's inputs are untouched.
        ff_input[self.input_lq_index] = lq_input
        ff_input[self.recurrent_index] = recurrent_input
        gen_out = gen(*ff_input)
        if isinstance(gen_out, torch.Tensor):
            gen_out = [gen_out]
        for i, out_key in enumerate(self.output):
            results[out_key].append(gen_out[i])
        return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]

    def forward(self, state):
        """Generate outputs for the base input and every sampled chain link.

        Args:
            state: Mapping from which generator inputs are extracted and which
                holds the HQ patches under ``self.hq_key``.

        Returns:
            A dict mapping each key in ``self.output`` (plus
            ``self.hq_output_key`` for the matching HQ patches) to the
            per-run tensors stacked along dim 1.
        """
        gen = self.env['generators'][self.gen_key]
        inputs = extract_params_from_state(self.input, state)
        # Normalize to lists BEFORE indexing: indexing a bare tensor here would
        # slice its first dimension instead of selecting the LQ input.
        if not isinstance(inputs, list):
            inputs = [inputs]
        if not isinstance(self.output, list):
            self.output = [self.output]
        lq_inputs = inputs[self.input_lq_index]
        hq_inputs = state[self.hq_key]

        # Maps each output key to the list of outputs produced by feeding each
        # progressive lq input into the generator.
        results = {out_key: [] for out_key in self.output}
        results_hq = []

        # Each per-patch LQ input is 4-dimensional; only its last two sizes
        # (height, width) are needed below.
        _, _, h, w = lq_inputs[:, 0].shape
        base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
        results_hq.append(hq_inputs[:, 0])
        input_chains = self.get_input_chains()
        debug_index = 0
        for chain in input_chains:
            chain_input = [lq_inputs[:, 0]]
            chain_output = [base_hq_out]
            recurrent_hq = base_hq_out
            recurrent = base_recurrent
            for link in chain:  # Remember, `link` is a MultiscaleTreeNode.
                if self.feed_gen_output_into_input:
                    # Crop an (h, w) window out of the previous HQ output at
                    # the link's quadrant. Assumes the generator output is 2x
                    # the LQ resolution - TODO confirm.
                    top = int(link.top * 2 * h)
                    left = int(link.left * 2 * w)
                    lq_input = recurrent_hq[:, :, top:top+h, left:left+w]
                else:
                    lq_input = lq_inputs[:, link.index]
                chain_input.append(lq_input)
                recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
                chain_output.append(recurrent_hq)
                results_hq.append(hq_inputs[:, link.index])

            if self.env['step'] % self.VISUAL_DEBUG_INTERVAL == 0:
                self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
                debug_index += 1
        results[self.hq_output_key] = results_hq
        return {k: torch.stack(v, dim=1) for k, v in results.items()}

    def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
        """Save the inputs and outputs of one chain as PNG files.

        Does nothing when ``self.env['rank'] > 0``. Files are written to
        ``<base_path>/../visual_dbg/<label>/<step>/`` and named
        ``<it>_<n>_input.png`` / ``<it>_<n>_output.png`` with ``n`` counting
        from 1 along the chain.
        """
        if self.env['rank'] > 0:
            return
        lbl = 'generator_recurrent' if self.feed_gen_output_into_input else 'generator_regular'
        base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        for ind, (chain_in, chain_out) in enumerate(zip(chain_inputs, chain_outputs), start=1):
            torchvision.utils.save_image(chain_in, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
            torchvision.utils.save_image(chain_out, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structuredr2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_chained.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_constrained.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_chained_structured.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user