15e00e9014
Note: autocast is broken when also using checkpoint(). Overcome this by modifying torch's checkpoint() function in place to also use autocast.
131 lines
5.7 KiB
Python
131 lines
5.7 KiB
Python
import os
|
|
import random
|
|
|
|
import torch
|
|
import torchvision
|
|
from torch.cuda.amp import autocast
|
|
|
|
from data.multiscale_dataset import build_multiscale_patch_index_map
|
|
from models.steps.injectors import Injector
|
|
from models.steps.losses import extract_params_from_state
|
|
from models.steps.tecogan_losses import extract_inputs_index
|
|
import os.path as osp
|
|
|
|
|
|
def create_progressive_zoom_injector(opt, env):
    """Build the progressive-zoom injector described by ``opt``.

    Args:
        opt: injector options; ``opt['type']`` selects the injector class.
        env: training environment forwarded to the injector constructor.

    Returns:
        A ``ProgressiveGeneratorInjector`` when ``opt['type']`` is
        ``'progressive_zoom_generator'``, otherwise ``None`` so the caller can
        try other injector factories.

    Raises:
        KeyError: if ``opt`` has no ``'type'`` entry.
    """
    # Named `injector_type` (not `type`) to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'progressive_zoom_generator':
        return ProgressiveGeneratorInjector(opt, env)
    return None
|
|
|
|
|
|
class ProgressiveGeneratorInjector(Injector):
    """Runs a generator progressively along random chains of a multiscale patch tree.

    The generator is first run on the base LQ input (``lq_inputs[:, 0]``). Then,
    for each of ``num_branches`` randomly sampled leaves of the multiscale patch
    index map, it is re-run once per node on the chain leading down to that leaf.
    At each step the previous recurrent tensor is cropped to a half-size window
    and upsampled 2x (nearest). All generator outputs, plus the matching HQ
    targets, are concatenated along dim 0 so ordinary losses can be applied.

    Options read from ``opt``:
        generator: key of the generator in ``env['generators']``.
        hq: state key holding the HQ images.
        hq_output: result key under which the matching HQ images are emitted.
        input_lq_index: position of the LQ input among the generator inputs (default 0).
        output_hq_index: position of the HQ image among the generator outputs.
        recurrent_output_index: position of the recurrent tensor among the generator outputs.
        recurrent_index: position of the recurrent input among the generator inputs.
        depth: depth passed to ``build_multiscale_patch_index_map``.
        num_branches: number of leaves sampled per forward pass.
        feed_gen_output_into_input: when truthy, crops of the previous generator
            HQ output are used as the next LQ input instead of ``lq_inputs[:, link.index]``.

    NOTE(review): ``self.input`` / ``self.output`` / ``self.env`` are presumably set
    by the ``Injector`` base class; confirm there.
    """

    def __init__(self, opt, env):
        super(ProgressiveGeneratorInjector, self).__init__(opt, env)
        self.gen_key = opt['generator']
        self.hq_key = opt['hq']  # The key where HQ images are stored.
        self.hq_output_key = opt['hq_output']  # The key where HQ images corresponding with generated images are stored.
        self.input_lq_index = opt.get('input_lq_index', 0)
        self.output_hq_index = opt['output_hq_index']
        self.recurrent_output_index = opt['recurrent_output_index']
        self.recurrent_index = opt['recurrent_index']
        self.depth = opt['depth']
        self.number_branches = opt['num_branches']  # Number of input branches to randomly choose for generation. This defines the output shape.
        self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
        self.feed_gen_output_into_input = opt['feed_gen_output_into_input']

    def _output_keys(self):
        """Return ``self.output`` as a list of result keys (it may be a single key)."""
        return self.output if isinstance(self.output, list) else [self.output]

    def get_input_chains(self):
        """Sample ``self.number_branches`` leaves and return one chain of nodes per leaf.

        Each chain is ordered from the node just below the root down to the leaf.
        The root itself (the base input) is excluded for efficiency, because
        ``forward`` processes it once for all chains. Each link is a
        MultiscaleTreeNode.
        """
        leaves = random.sample(self.multiscale_leaves, self.number_branches)
        chains = []
        for leaf in leaves:
            chain = [leaf]
            node = leaf.parent
            # Walk upward, stopping before the root. Guarding `node is not None`
            # keeps a root-level leaf from raising AttributeError.
            while node is not None and node.parent is not None:
                chain.append(node)
                node = node.parent
            chain.reverse()  # Top-down order: shallowest node first, leaf last.
            chains.append(chain)
        return chains

    def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
        """Run ``gen`` once with the LQ and recurrent slots replaced.

        Args:
            gen: the generator module.
            inputs: list of generator inputs; it is copied, never mutated.
            results: dict mapping each output key to a list; every generator
                output is appended to the list of its key.
            lq_input: value placed at ``self.input_lq_index``.
            recurrent_input: value placed at ``self.recurrent_index`` (``None`` on the first step).

        Returns:
            ``(hq_output, recurrent_output)`` picked from the generator outputs.
        """
        ff_input = inputs.copy()
        ff_input[self.input_lq_index] = lq_input
        ff_input[self.recurrent_index] = recurrent_input

        with autocast(enabled=self.env['opt']['fp16']):
            gen_out = gen(*ff_input)

        if isinstance(gen_out, torch.Tensor):
            gen_out = [gen_out]
        # Use the normalized key list: iterating a bare string key would yield characters.
        for i, out_key in enumerate(self._output_keys()):
            results[out_key].append(gen_out[i])
        return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]

    def forward(self, state):
        """Generate outputs for the base input and every sampled chain.

        Returns:
            A dict mapping each output key, plus ``self.hq_output_key``, to the
            per-step tensors concatenated along dim 0.
        """
        gen = self.env['generators'][self.gen_key]
        inputs = extract_params_from_state(self.input, state)
        # Normalize before indexing: indexing a bare tensor would slice its first dim.
        if not isinstance(inputs, list):
            inputs = [inputs]
        lq_inputs = inputs[self.input_lq_index]
        hq_inputs = state[self.hq_key]
        # One list per output key, collecting the outputs from every progressive step.
        results = {out_key: [] for out_key in self._output_keys()}
        results_hq = []

        # lq_inputs[:, 0] unpacks to 4 dims, so lq_inputs is 5-D, indexed on dim 1 by tree node.
        # h, w presumably are the LQ patch height/width.
        _, _, h, w = lq_inputs[:, 0].shape
        base_hq_out, base_recurrent = self.feed_forward(gen, inputs, results, lq_inputs[:, 0], None)
        results_hq.append(hq_inputs[:, 0])
        input_chains = self.get_input_chains()
        for debug_index, chain in enumerate(input_chains):
            chain_input = [lq_inputs[:, 0]]
            chain_output = [base_hq_out]
            recurrent_hq = base_hq_out
            recurrent = base_recurrent
            for link in chain:  # Remember, `link` is a MultiscaleTreeNode.
                # link.top / link.left presumably are fractional offsets of the patch in its parent.
                top = int(link.top * h)
                left = int(link.left * w)
                # Zoom the recurrent state into this link's half-size window.
                recurrent = torch.nn.functional.interpolate(
                    recurrent[:, :, top:top + h // 2, left:left + w // 2], scale_factor=2, mode="nearest")
                if self.feed_gen_output_into_input:
                    # Offsets are doubled because recurrent_hq is indexed at twice the LQ offsets.
                    top *= 2
                    left *= 2
                    lq_input = recurrent_hq[:, :, top:top + h, left:left + w]
                else:
                    lq_input = lq_inputs[:, link.index]
                chain_input.append(lq_input)
                recurrent_hq, recurrent = self.feed_forward(gen, inputs, results, lq_input, recurrent)
                chain_output.append(recurrent_hq)
                results_hq.append(hq_inputs[:, link.index])

            if self.env['step'] % 50 == 0:
                self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
        results[self.hq_output_key] = results_hq

        # Results are concatenated into the batch dimension, to allow normal losses to be used against the output.
        return {k: torch.cat(v, dim=0) for k, v in results.items()}

    def produce_progressive_visual_debugs(self, chain_inputs, chain_outputs, it):
        """Save each input/output pair of one chain as PNGs (only on rank 0).

        Files go to ``<base_path>/../visual_dbg/<label>/<step>/<it>_<n>_{input,output}.png``,
        where ``n`` counts from 1.
        """
        if self.env['rank'] > 0:
            return
        lbl = 'generator_recurrent' if self.feed_gen_output_into_input else 'generator_regular'
        base_path = osp.join(self.env['base_path'], "..", "visual_dbg", lbl, str(self.env['step']))
        os.makedirs(base_path, exist_ok=True)
        for ind, (i, o) in enumerate(zip(chain_inputs, chain_outputs), start=1):
            torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
            torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|