forked from mrq/DL-Art-School
43c4f92123
This contributes a significant speedup to training this type of network since losses can operate on the entire prediction spectrum at once.
127 lines
5.5 KiB
Python
127 lines
5.5 KiB
Python
import os
|
|
import random
|
|
|
|
import torch
|
|
import torchvision
|
|
|
|
from data.multiscale_dataset import build_multiscale_patch_index_map
|
|
from models.steps.injectors import Injector
|
|
from models.steps.losses import extract_params_from_state
|
|
from models.steps.tecogan_losses import extract_inputs_index
|
|
import os.path as osp
|
|
|
|
|
|
def create_progressive_zoom_injector(opt, env):
|
|
type = opt['type']
|
|
if type == 'progressive_zoom_generator':
|
|
return ProgressiveGeneratorInjector(opt, env)
|
|
return None
|
|
|
|
|
|
class ProgressiveGeneratorInjector(Injector):
|
|
def __init__(self, opt, env):
|
|
super(ProgressiveGeneratorInjector, self).__init__(opt, env)
|
|
self.gen_key = opt['generator']
|
|
self.hq_key = opt['hq'] # The key where HQ images are stored.
|
|
self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored.
|
|
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
|
self.output_hq_index = opt['output_hq_index']
|
|
self.recurrent_output_index = opt['recurrent_output_index']
|
|
self.recurrent_index = opt['recurrent_index']
|
|
self.depth = opt['depth']
|
|
self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape.
|
|
self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
|
|
self.feed_gen_output_into_input = opt['feed_gen_output_into_input']
|
|
|
|
# Given a set of multiscale inputs, selects self.num_branches leaves and produces that many chains of inputs,
|
|
# excluding the base input for efficiency reasons.
|
|
# Output is a list of chains. Each chain is itself a list of links. Each link is MultiscaleTreeNode
|
|
def get_input_chains(self):
|
|
leaves = random.sample(self.multiscale_leaves, self.number_branches)
|
|
chains = []
|
|
for leaf in leaves:
|
|
chain = [leaf]
|
|
node = leaf.parent
|
|
while node.parent is not None:
|
|
chain.insert(0, node)
|
|
node = node.parent
|
|
chains.append(chain)
|
|
return chains
|
|
|
|
def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
|
|
ff_input = inputs.copy()
|
|
ff_input[self.input_lq_index] = lq_input
|
|
ff_input[self.recurrent_index] = recurrent_input
|
|
gen_out = gen(*ff_input)
|
|
if isinstance(gen_out, torch.Tensor):
|
|
gen_out = [gen_out]
|
|
for i, out_key in enumerate(self.output):
|
|
results[out_key].append(gen_out[i])
|
|
return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]
|
|
|
|
def forward(self, state):
|
|
gen = self.env['generators'][self.gen_key]
|
|
inputs = extract_params_from_state(self.input, state)
|
|
lq_inputs = inputs[self.input_lq_index]
|
|
hq_inputs = state[self.hq_key]
|
|
output = self.output
|
|
if not isinstance(inputs, list):
|
|
inputs = [inputs]
|
|
if not isinstance(self.output, list):
|
|
output = [self.output]
|
|
results = {} # A list of outputs produced by feeding each progressive lq input into the generator.
|
|
results_hq = []
|
|
for out_key in output:
|
|
results[out_key] = []
|
|
|
|
b, f, h, w = lq_inputs[:, 0].shape
|
|
base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
|
|
results_hq.append(hq_inputs[:, 0])
|
|
input_chains = self.get_input_chains()
|
|
debug_index = 0
|
|
for chain in input_chains:
|
|
chain_input = [lq_inputs[:, 0]]
|
|
chain_output = [base_hq_out]
|
|
recurrent_hq = base_hq_out
|
|
recurrent = base_recurrent
|
|
for link in chain: # Remember, `link` is a MultiscaleTreeNode.
|
|
top = int(link.top * h)
|
|
left = int(link.left * w)
|
|
recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
|
|
if self.feed_gen_output_into_input:
|
|
top *= 2
|
|
left *= 2
|
|
lq_input = recurrent_hq[:, :, top:top+h, left:left+w]
|
|
else:
|
|
lq_input = lq_inputs[:, link.index]
|
|
chain_input.append(lq_input)
|
|
recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
|
|
chain_output.append(recurrent_hq)
|
|
results_hq.append(hq_inputs[:, link.index])
|
|
|
|
if self.env['step'] % 50 == 0:
|
|
self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
|
|
debug_index += 1
|
|
results[self.hq_output_key] = results_hq
|
|
|
|
# Results are concatenated into the batch dimension, to allow normal losses to be used against the output.
|
|
for k, v in results.items():
|
|
results[k] = torch.cat(v, dim=0)
|
|
return results
|
|
|
|
|
|
def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
|
|
if self.env['rank'] > 0:
|
|
return
|
|
if self.feed_gen_output_into_input:
|
|
lbl = 'generator_recurrent'
|
|
else:
|
|
lbl = 'generator_regular'
|
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
|
|
os.makedirs(base_path, exist_ok=True)
|
|
ind = 1
|
|
for i, o in zip(chain_inputs, chain_outputs):
|
|
torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
|
|
torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|
|
ind += 1
|