DL-Art-School/codes/trainer/custom_training_components/progressive_zoom.py
James Betker 63cf3d3126 Injector auto-registration
I love it!
2020-12-29 20:58:02 -07:00

140 lines
5.9 KiB
Python

import os
import random
import torch
import torchvision
from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map
from trainer.inject import Injector
from trainer.losses import extract_params_from_state
import os.path as osp
def create_progressive_zoom_injector(opt, env):
    """Build the injector described by ``opt`` if it belongs to this module.

    Args:
        opt: injector options; ``opt['type']`` selects which injector to build.
        env: environment object passed through unchanged to the injector.

    Returns:
        A ``ProgressiveGeneratorInjector`` when ``opt['type']`` is
        ``'progressive_zoom_generator'``, otherwise ``None``.

    Raises:
        KeyError: if ``opt`` has no ``'type'`` entry.
    """
    # Named `injector_type` rather than `type` to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'progressive_zoom_generator':
        return ProgressiveGeneratorInjector(opt, env)
    return None
class ProgressiveGeneratorInjector(Injector):
    """Runs a generator over a base input plus randomly sampled multiscale "zoom" chains.

    The generator is first applied to the base (index 0) LQ input. Then, for each of
    ``num_branches`` randomly sampled leaves of the multiscale patch tree, the generator is
    applied once per link of the chain leading from just below the root down to that leaf.
    Every generator output is collected per output key and concatenated along dim 0 so that
    ordinary losses can be applied to the result.

    Options read from ``opt`` (in addition to whatever ``Injector`` consumes):
        generator: key into ``env['generators']`` selecting the generator to run.
        hq: state key where the HQ images are stored.
        hq_output: result key under which the HQ images corresponding to each generated
            image are stored.
        input_lq_index: position of the LQ input among the generator inputs (default 0).
        output_hq_index: position of the HQ image among the generator outputs.
        recurrent_output_index / recurrent_index: when ``recurrent_output_index`` is present,
            recurrence is enabled: the generator output at ``recurrent_output_index`` is
            cropped, upsampled and fed into generator input ``recurrent_index`` for the next
            link of the chain.
        depth: depth passed to ``build_multiscale_patch_index_map``.
        num_branches: number of leaves (chains) sampled per forward pass; defines the
            output shape.
        feed_gen_output_into_input: when true, each link's LQ input is cropped from the
            previous generator output instead of being taken from the dataset's LQ stack.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen_key = opt['generator']
        self.hq_key = opt['hq']  # The key where HQ images are stored.
        # The key where HQ images corresponding with generated images are stored.
        self.hq_output_key = opt['hq_output']
        self.input_lq_index = opt.get('input_lq_index', 0)
        self.output_hq_index = opt['output_hq_index']
        self.recurrence = 'recurrent_output_index' in opt
        if self.recurrence:
            self.recurrent_output_index = opt['recurrent_output_index']
            self.recurrent_index = opt['recurrent_index']
        self.depth = opt['depth']
        # Number of input branches to randomly choose for generation. This defines the output shape.
        self.number_branches = opt['num_branches']
        self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
        self.feed_gen_output_into_input = opt['feed_gen_output_into_input']

    def get_input_chains(self):
        """Sample ``self.number_branches`` leaves and return the chain of links leading to each.

        Each chain is a list of links (MultiscaleTreeNode per the original comment) ordered
        from the node directly below the root down to the sampled leaf. The root itself (the
        base input) is excluded for efficiency: ``forward`` processes it once, separately.
        """
        leaves = random.sample(self.multiscale_leaves, self.number_branches)
        chains = []
        for leaf in leaves:
            # Walk upward collecting ancestors, stopping before the root, then flip to
            # root-to-leaf order (avoids repeated O(n) list.insert(0, ...)).
            chain = [leaf]
            node = leaf.parent
            while node.parent is not None:
                chain.append(node)
                node = node.parent
            chain.reverse()
            chains.append(chain)
        return chains

    def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
        """Run ``gen`` once with ``lq_input`` (and, if recurrent, ``recurrent_input``) substituted.

        Every generator output is appended to ``results[key]`` for the matching key of
        ``self.output``. Note: ``self.output`` must already be a list, as ``forward`` ensures.

        Returns:
            ``(hq_output, recurrent_output)``: the generator output at ``self.output_hq_index``
            and, when recurrence is enabled, the output at ``self.recurrent_output_index``
            (otherwise ``None``).
        """
        ff_input = inputs.copy()  # Copy so the shared input list is never mutated.
        ff_input[self.input_lq_index] = lq_input
        if self.recurrence:
            ff_input[self.recurrent_index] = recurrent_input
        with autocast(enabled=self.env['opt']['fp16']):
            gen_out = gen(*ff_input)
        if isinstance(gen_out, torch.Tensor):
            gen_out = [gen_out]
        for i, out_key in enumerate(self.output):
            results[out_key].append(gen_out[i])
        recurrent = gen_out[self.recurrent_output_index] if self.recurrence else None
        return gen_out[self.output_hq_index], recurrent

    def forward(self, state):
        """Run the base pass plus one generator pass per link of every sampled chain.

        Returns:
            dict mapping each key of ``self.output``, plus ``self.hq_output_key``, to the
            tensors collected for it, concatenated along dim 0.
        """
        gen = self.env['generators'][self.gen_key]
        inputs = extract_params_from_state(self.input, state)
        # Normalize to a list *before* indexing with input_lq_index; indexing a non-list
        # value would select from that value itself rather than pick the LQ input.
        if not isinstance(inputs, list):
            inputs = [inputs]
        if not isinstance(self.output, list):
            self.output = [self.output]
        lq_inputs = inputs[self.input_lq_index]
        hq_inputs = state[self.hq_key]

        # Maps each output key to the list of tensors produced for it by every generator pass.
        results = {out_key: [] for out_key in self.output}
        results_hq = []
        # Index 1 of lq_inputs/hq_inputs selects the multiscale patch; index 0 is the base input.
        _, _, h, w = lq_inputs[:, 0].shape
        base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
        results_hq.append(hq_inputs[:, 0])

        debug_index = 0
        for chain in self.get_input_chains():
            chain_input = [lq_inputs[:, 0]]
            chain_output = [base_hq_out]
            recurrent_hq = base_hq_out
            recurrent = base_recurrent
            for link in chain:  # Remember, `link` is a MultiscaleTreeNode.
                # Link offset scaled to LQ pixel coordinates (link.top/left presumably fractions).
                top = int(link.top * h)
                left = int(link.left * w)
                if recurrent is not None:
                    # Crop an (h//2, w//2) window at the link offset and upsample it 2x.
                    recurrent = torch.nn.functional.interpolate(
                        recurrent[:, :, top:top + h // 2, left:left + w // 2],
                        scale_factor=2, mode="nearest")
                if self.feed_gen_output_into_input:
                    # The previous generator output is cropped at doubled coordinates,
                    # presumably because the generator upsamples 2x -- TODO confirm.
                    top *= 2
                    left *= 2
                    lq_input = recurrent_hq[:, :, top:top + h, left:left + w]
                else:
                    lq_input = lq_inputs[:, link.index]
                chain_input.append(lq_input)
                recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
                chain_output.append(recurrent_hq)
                results_hq.append(hq_inputs[:, link.index])
            if self.env['step'] % 50 == 0:
                self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
                debug_index += 1

        results[self.hq_output_key] = results_hq
        # Results are concatenated into the batch dimension, to allow normal losses to be used
        # against the output.
        return {k: torch.cat(v, dim=0) for k, v in results.items()}

    def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
        """Save every input/output image pair of one chain to disk (rank 0 only).

        Files are written to ``<base_path>/../../models/visual_dbg/<label>/<step>/`` as
        ``<it>_<n>_input.png`` / ``<it>_<n>_output.png``, with ``n`` counting from 1.
        """
        if self.env['rank'] > 0:
            return
        lbl = 'generator_recurrent' if self.feed_gen_output_into_input else 'generator_regular'
        base_path = osp.join(self.env['base_path'], "../../models", "visual_dbg", lbl, str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        for ind, (i, o) in enumerate(zip(chain_inputs, chain_outputs), start=1):
            torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
            torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))