# Forked from mrq/DL-Art-School (original listing: 140 lines, 5.9 KiB, Python).
import os
|
|
import random
|
|
|
|
import torch
|
|
import torchvision
|
|
from torch.cuda.amp import autocast
|
|
|
|
from data.multiscale_dataset import build_multiscale_patch_index_map
|
|
from trainer.injectors import Injector
|
|
from trainer.losses import extract_params_from_state
|
|
import os.path as osp
|
|
|
|
|
|
def create_progressive_zoom_injector(opt, env):
    """Build the progressive-zoom injector described by ``opt``, if recognized.

    Args:
        opt: injector options; ``opt['type']`` selects the injector class.
        env: training environment, passed through unchanged to the injector.

    Returns:
        A ``ProgressiveGeneratorInjector`` when ``opt['type']`` is
        ``'progressive_zoom_generator'``; otherwise ``None``.

    Raises:
        KeyError: if ``opt`` has no ``'type'`` entry.
    """
    # Named injector_type rather than `type` to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'progressive_zoom_generator':
        return ProgressiveGeneratorInjector(opt, env)
    return None
|
|
|
|
|
|
class ProgressiveGeneratorInjector(Injector):
    """Injector that runs a generator progressively down a multiscale patch tree.

    The generator is first applied to the base LQ input (``lq_inputs[:, 0]``).
    Then ``num_branches`` randomly chosen leaves of the multiscale index map are
    expanded into chains of their ancestors, and the generator is applied once per
    link of every chain, each time on a zoomed-in input. All generator outputs, and
    the HQ targets matching each of them, are concatenated along the batch
    dimension so that ordinary losses can be computed against them.

    Options read from ``opt`` (in addition to whatever ``Injector`` consumes):
        generator: key into ``env['generators']`` selecting the generator.
        hq: state key where the HQ images are stored.
        hq_output: result key under which the HQ images corresponding to each
            generated image are emitted.
        input_lq_index: position of the LQ input among the extracted inputs
            (default 0).
        output_hq_index: position of the HQ image within the generator output.
        recurrent_output_index / recurrent_index: optional. When
            ``recurrent_output_index`` is present, the generator output at that
            position is fed back into the next call at input position
            ``recurrent_index``.
        depth: depth passed to ``build_multiscale_patch_index_map``.
        num_branches: number of leaves (chains) sampled per forward pass; this
            determines how many outputs are produced.
        feed_gen_output_into_input: if true, each link's input is cropped from
            the previous generator output instead of taken from the LQ inputs.
    """

    def __init__(self, opt, env):
        super(ProgressiveGeneratorInjector, self).__init__(opt, env)
        self.gen_key = opt['generator']
        self.hq_key = opt['hq']  # The key where HQ images are stored.
        # The key where HQ images corresponding with generated images are stored.
        self.hq_output_key = opt['hq_output']
        self.input_lq_index = opt.get('input_lq_index', 0)
        self.output_hq_index = opt['output_hq_index']
        if 'recurrent_output_index' in opt:
            self.recurrent_output_index = opt['recurrent_output_index']
            self.recurrent_index = opt['recurrent_index']
            self.recurrence = True
        else:
            self.recurrence = False
        self.depth = opt['depth']
        # Number of input branches to randomly choose for generation. This defines the output shape.
        self.number_branches = opt['num_branches']
        self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
        self.feed_gen_output_into_input = opt['feed_gen_output_into_input']

    def get_input_chains(self):
        """Sample ``self.number_branches`` leaves and return their ancestor chains.

        Each chain is a list of MultiscaleTreeNode links ordered from the node just
        below the root down to the sampled leaf. The root itself (the base input)
        is excluded for efficiency, because ``forward`` processes it once for all
        chains.

        Returns:
            A list of chains; each chain is a list of MultiscaleTreeNode.
        """
        leaves = random.sample(self.multiscale_leaves, self.number_branches)
        chains = []
        for leaf in leaves:
            chain = [leaf]
            node = leaf.parent
            # Walk up to (but not including) the root, prepending each ancestor.
            while node.parent is not None:
                chain.insert(0, node)
                node = node.parent
            chains.append(chain)
        return chains

    def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
        """Run the generator once and record its outputs.

        Args:
            gen: the generator module.
            inputs: list of generator inputs. It is shallow-copied and not mutated.
            results: dict mapping each key in ``self.output`` to a list. The i-th
                generator output is appended to ``results[self.output[i]]``.
            lq_input: tensor substituted at position ``self.input_lq_index``.
            recurrent_input: substituted at position ``self.recurrent_index`` when
                recurrence is enabled; ignored otherwise.

        Returns:
            ``(hq_output, recurrent_output)``: the generator output at
            ``self.output_hq_index``, and the output at
            ``self.recurrent_output_index`` (or ``None`` without recurrence).
        """
        ff_input = inputs.copy()
        ff_input[self.input_lq_index] = lq_input
        if self.recurrence:
            ff_input[self.recurrent_index] = recurrent_input

        with autocast(enabled=self.env['opt']['fp16']):
            gen_out = gen(*ff_input)

        # Normalize a single-tensor return so it can be indexed like a tuple.
        if isinstance(gen_out, torch.Tensor):
            gen_out = [gen_out]
        for i, out_key in enumerate(self.output):
            results[out_key].append(gen_out[i])
        recurrent = None
        if self.recurrence:
            recurrent = gen_out[self.recurrent_output_index]
        return gen_out[self.output_hq_index], recurrent

    def forward(self, state):
        """Run the generator on the base input and on every sampled zoom chain.

        Args:
            state: training state dict. It must contain the keys named by
                ``self.input`` and ``self.hq_key``. LQ and HQ inputs are indexed
                along dim 1 by multiscale node index. Presumably their layout is
                (batch, num_patches, C, H, W) — TODO confirm against the dataset.

        Returns:
            Dict mapping each output key, plus ``self.hq_output_key``, to the
            tensors produced for it, concatenated along dim 0 in generation order:
            the base input first, then each link of each chain.
        """
        gen = self.env['generators'][self.gen_key]
        inputs = extract_params_from_state(self.input, state)
        # Normalize to a list *before* indexing. Indexing a bare tensor here would
        # select a batch element instead of the LQ input itself.
        if not isinstance(inputs, list):
            inputs = [inputs]
        lq_inputs = inputs[self.input_lq_index]
        hq_inputs = state[self.hq_key]
        # feed_forward() enumerates self.output, so it must be a list from here on.
        if not isinstance(self.output, list):
            self.output = [self.output]
        output = self.output

        # One list of generated tensors per output key; concatenated at the end.
        results = {out_key: [] for out_key in output}
        results_hq = []

        _, _, h, w = lq_inputs[:, 0].shape
        base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
        results_hq.append(hq_inputs[:, 0])

        for debug_index, chain in enumerate(self.get_input_chains()):
            chain_input = [lq_inputs[:, 0]]
            chain_output = [base_hq_out]
            recurrent_hq = base_hq_out
            recurrent = base_recurrent
            for link in chain:  # Remember, `link` is a MultiscaleTreeNode.
                # Node coordinates are scaled by the LQ size. Presumably they are
                # fractions in [0, 1) — TODO confirm.
                top = int(link.top * h)
                left = int(link.left * w)
                if recurrent is not None:
                    # Crop an (h/2, w/2) window at (top, left) from the recurrent
                    # state and upsample it 2x (nearest) for the next step.
                    recurrent = torch.nn.functional.interpolate(
                        recurrent[:, :, top:top + h // 2, left:left + w // 2],
                        scale_factor=2, mode="nearest")
                if self.feed_gen_output_into_input:
                    # Crop the previous output at doubled coordinates. This
                    # assumes the generator output is 2x the LQ resolution
                    # (NOTE(review): confirm).
                    top *= 2
                    left *= 2
                    lq_input = recurrent_hq[:, :, top:top + h, left:left + w]
                else:
                    lq_input = lq_inputs[:, link.index]
                chain_input.append(lq_input)
                recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
                chain_output.append(recurrent_hq)
                # Keep HQ targets aligned one-to-one with generated outputs.
                results_hq.append(hq_inputs[:, link.index])

            if self.env['step'] % 50 == 0:
                self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)

        results[self.hq_output_key] = results_hq

        # Results are concatenated into the batch dimension, to allow normal losses to be used against the output.
        return {k: torch.cat(v, dim=0) for k, v in results.items()}

    def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
        """Save the inputs and outputs of one chain as images for inspection.

        Only rank 0 writes. Images are saved under
        ``<base_path>/../../models/visual_dbg/<label>/<step>/``, where
        ``<label>`` reflects ``feed_gen_output_into_input``. They are named
        ``<it>_<n>_input.png`` and ``<it>_<n>_output.png``, with ``n`` counting
        from 1 along the chain.

        Args:
            chain_inputs: tensors fed to the generator along the chain.
            chain_outputs: the corresponding generator HQ outputs.
            it: chain index, used as the filename prefix.
        """
        if self.env['rank'] > 0:
            return
        lbl = 'generator_recurrent' if self.feed_gen_output_into_input else 'generator_regular'
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", lbl, str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        for ind, (i, o) in enumerate(zip(chain_inputs, chain_outputs), start=1):
            torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
            torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|