Fixes and additional support for progressive zoom

This commit is contained in:
James Betker 2020-10-30 09:59:54 -06:00
parent a3918fa808
commit 74738489b9
7 changed files with 32 additions and 21 deletions

View File

@ -22,7 +22,7 @@ class MultiScaleDataset(data.Dataset):
self.num_scales = self.opt['num_scales'] self.num_scales = self.opt['num_scales']
self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.hq_size_cap = self.tile_size * 2 ** self.num_scales
self.scale = self.opt['scale'] self.scale = self.opt['scale']
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1]) self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
self.corruptor = ImageCorruptor(opt) self.corruptor = ImageCorruptor(opt)
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping

View File

@ -267,12 +267,13 @@ class ExtensibleTrainer(BaseModel):
with torch.no_grad(): with torch.no_grad():
# This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that. # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
# Or, we run the entire chain of steps in "train" mode and use eval.output_state. # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
if 'injector' in self.opt['eval'].keys(): if 'injectors' in self.opt['eval'].keys():
# Need to move from mega_batch mode to batch mode (remove chunks)
state = {} state = {}
for inj in self.opt['eval']['injectors'].values():
# Need to move from mega_batch mode to batch mode (remove chunks)
for k, v in self.dstate.items(): for k, v in self.dstate.items():
state[k] = v[0] state[k] = v[0]
inj = create_injector(self.opt['eval']['injector'], self.env) inj = create_injector(inj, self.env)
state.update(inj(state)) state.update(inj(state))
else: else:
# Iterate through the steps, performing them one at a time. # Iterate through the steps, performing them one at a time.

View File

@ -188,6 +188,8 @@ class RRDBNet(nn.Module):
if ref is None: if ref is None:
ref = torch.zeros_like(x_lg) ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref], dim=1) x_lg = torch.cat([x_lg, ref], dim=1)
else:
x_lg = x
feat = self.conv_first(x_lg) feat = self.conv_first(x_lg)
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)) body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
feat = feat + body_feat feat = feat + body_feat

View File

@ -27,8 +27,12 @@ class ProgressiveGeneratorInjector(Injector):
self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored. self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored.
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] self.output_hq_index = opt['output_hq_index']
if 'recurrent_output_index' in opt.keys():
self.recurrent_output_index = opt['recurrent_output_index'] self.recurrent_output_index = opt['recurrent_output_index']
self.recurrent_index = opt['recurrent_index'] self.recurrent_index = opt['recurrent_index']
self.recurrence = True
else:
self.recurrence = False
self.depth = opt['depth'] self.depth = opt['depth']
self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape. self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape.
self.multiscale_leaves = build_multiscale_patch_index_map(self.depth) self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
@ -52,6 +56,7 @@ class ProgressiveGeneratorInjector(Injector):
def feed_forward(self, gen, inputs, results, lq_input, recurrent_input): def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
ff_input = inputs.copy() ff_input = inputs.copy()
ff_input[self.input_lq_index] = lq_input ff_input[self.input_lq_index] = lq_input
if self.recurrence:
ff_input[self.recurrent_index] = recurrent_input ff_input[self.recurrent_index] = recurrent_input
with autocast(enabled=self.env['opt']['fp16']): with autocast(enabled=self.env['opt']['fp16']):
@ -61,7 +66,10 @@ class ProgressiveGeneratorInjector(Injector):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output): for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i]) results[out_key].append(gen_out[i])
return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index] recurrent = None
if self.recurrence:
recurrent = gen_out[self.recurrent_output_index]
return gen_out[self.output_hq_index], recurrent
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.gen_key] gen = self.env['generators'][self.gen_key]
@ -73,6 +81,7 @@ class ProgressiveGeneratorInjector(Injector):
inputs = [inputs] inputs = [inputs]
if not isinstance(self.output, list): if not isinstance(self.output, list):
output = [self.output] output = [self.output]
self.output = output
results = {} # A list of outputs produced by feeding each progressive lq input into the generator. results = {} # A list of outputs produced by feeding each progressive lq input into the generator.
results_hq = [] results_hq = []
for out_key in output: for out_key in output:
@ -91,6 +100,7 @@ class ProgressiveGeneratorInjector(Injector):
for link in chain: # Remember, `link` is a MultiscaleTreeNode. for link in chain: # Remember, `link` is a MultiscaleTreeNode.
top = int(link.top * h) top = int(link.top * h)
left = int(link.left * w) left = int(link.left * w)
if recurrent is not None:
recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest") recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
if self.feed_gen_output_into_input: if self.feed_gen_output_into_input:
top *= 2 top *= 2

View File

@ -13,7 +13,7 @@ from skimage import io
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_prog_mi1_rrdb_6bypass.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
@ -50,6 +50,7 @@ def main():
len(train_set), train_size)) len(train_set), train_size))
assert train_loader is not None assert train_loader is not None
'''
tq_ldr = tqdm(train_set.get_paths()) tq_ldr = tqdm(train_set.get_paths())
for path in tq_ldr: for path in tq_ldr:
try: try:
@ -58,6 +59,10 @@ def main():
except Exception as e: except Exception as e:
print("Error with %s" % (path,)) print("Error with %s" % (path,))
print(e) print(e)
'''
tq_ldr = tqdm(train_set)
for ds in tq_ldr:
pass
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -265,7 +265,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -23,8 +23,6 @@ def parse(opt_path, is_train=True):
for phase, dataset in opt['datasets'].items(): for phase, dataset in opt['datasets'].items():
phase = phase.split('_')[0] phase = phase.split('_')[0]
dataset['phase'] = phase dataset['phase'] = phase
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
dataset['scale'] = scale
is_lmdb = False is_lmdb = False
''' LMDB is not supported at this point with the mods I've been making. ''' LMDB is not supported at this point with the mods I've been making.
if dataset.get('dataroot_GT', None) is not None: if dataset.get('dataroot_GT', None) is not None:
@ -67,11 +65,6 @@ def parse(opt_path, is_train=True):
opt['path']['results_root'] = results_root opt['path']['results_root'] = results_root
opt['path']['log'] = results_root opt['path']['log'] = results_root
# network
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
if 'network_G' in opt.keys():
opt['network_G']['scale'] = scale
return opt return opt