forked from mrq/DL-Art-School
Fixes and additional support for progressive zoom
This commit is contained in:
parent
a3918fa808
commit
74738489b9
|
@ -22,7 +22,7 @@ class MultiScaleDataset(data.Dataset):
|
||||||
self.num_scales = self.opt['num_scales']
|
self.num_scales = self.opt['num_scales']
|
||||||
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
||||||
self.scale = self.opt['scale']
|
self.scale = self.opt['scale']
|
||||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
|
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
||||||
self.corruptor = ImageCorruptor(opt)
|
self.corruptor = ImageCorruptor(opt)
|
||||||
|
|
||||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||||
|
|
|
@ -267,12 +267,13 @@ class ExtensibleTrainer(BaseModel):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
|
# This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
|
||||||
# Or, we run the entire chain of steps in "train" mode and use eval.output_state.
|
# Or, we run the entire chain of steps in "train" mode and use eval.output_state.
|
||||||
if 'injector' in self.opt['eval'].keys():
|
if 'injectors' in self.opt['eval'].keys():
|
||||||
# Need to move from mega_batch mode to batch mode (remove chunks)
|
|
||||||
state = {}
|
state = {}
|
||||||
|
for inj in self.opt['eval']['injectors'].values():
|
||||||
|
# Need to move from mega_batch mode to batch mode (remove chunks)
|
||||||
for k, v in self.dstate.items():
|
for k, v in self.dstate.items():
|
||||||
state[k] = v[0]
|
state[k] = v[0]
|
||||||
inj = create_injector(self.opt['eval']['injector'], self.env)
|
inj = create_injector(inj, self.env)
|
||||||
state.update(inj(state))
|
state.update(inj(state))
|
||||||
else:
|
else:
|
||||||
# Iterate through the steps, performing them one at a time.
|
# Iterate through the steps, performing them one at a time.
|
||||||
|
|
|
@ -188,6 +188,8 @@ class RRDBNet(nn.Module):
|
||||||
if ref is None:
|
if ref is None:
|
||||||
ref = torch.zeros_like(x_lg)
|
ref = torch.zeros_like(x_lg)
|
||||||
x_lg = torch.cat([x_lg, ref], dim=1)
|
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||||
|
else:
|
||||||
|
x_lg = x
|
||||||
feat = self.conv_first(x_lg)
|
feat = self.conv_first(x_lg)
|
||||||
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
||||||
feat = feat + body_feat
|
feat = feat + body_feat
|
||||||
|
|
|
@ -27,8 +27,12 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored.
|
self.hq_output_key = opt['hq_output'] # The key where HQ images corresponding with generated images are stored.
|
||||||
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||||
self.output_hq_index = opt['output_hq_index']
|
self.output_hq_index = opt['output_hq_index']
|
||||||
|
if 'recurrent_output_index' in opt.keys():
|
||||||
self.recurrent_output_index = opt['recurrent_output_index']
|
self.recurrent_output_index = opt['recurrent_output_index']
|
||||||
self.recurrent_index = opt['recurrent_index']
|
self.recurrent_index = opt['recurrent_index']
|
||||||
|
self.recurrence = True
|
||||||
|
else:
|
||||||
|
self.recurrence = False
|
||||||
self.depth = opt['depth']
|
self.depth = opt['depth']
|
||||||
self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape.
|
self.number_branches = opt['num_branches'] # Number of input branches to randomly choose for generation. This defines the output shape.
|
||||||
self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
|
self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
|
||||||
|
@ -52,6 +56,7 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
|
def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
|
||||||
ff_input = inputs.copy()
|
ff_input = inputs.copy()
|
||||||
ff_input[self.input_lq_index] = lq_input
|
ff_input[self.input_lq_index] = lq_input
|
||||||
|
if self.recurrence:
|
||||||
ff_input[self.recurrent_index] = recurrent_input
|
ff_input[self.recurrent_index] = recurrent_input
|
||||||
|
|
||||||
with autocast(enabled=self.env['opt']['fp16']):
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
|
@ -61,7 +66,10 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
gen_out = [gen_out]
|
gen_out = [gen_out]
|
||||||
for i, out_key in enumerate(self.output):
|
for i, out_key in enumerate(self.output):
|
||||||
results[out_key].append(gen_out[i])
|
results[out_key].append(gen_out[i])
|
||||||
return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]
|
recurrent = None
|
||||||
|
if self.recurrence:
|
||||||
|
recurrent = gen_out[self.recurrent_output_index]
|
||||||
|
return gen_out[self.output_hq_index], recurrent
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.gen_key]
|
gen = self.env['generators'][self.gen_key]
|
||||||
|
@ -73,6 +81,7 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
if not isinstance(self.output, list):
|
if not isinstance(self.output, list):
|
||||||
output = [self.output]
|
output = [self.output]
|
||||||
|
self.output = output
|
||||||
results = {} # A list of outputs produced by feeding each progressive lq input into the generator.
|
results = {} # A list of outputs produced by feeding each progressive lq input into the generator.
|
||||||
results_hq = []
|
results_hq = []
|
||||||
for out_key in output:
|
for out_key in output:
|
||||||
|
@ -91,6 +100,7 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
for link in chain: # Remember, `link` is a MultiscaleTreeNode.
|
for link in chain: # Remember, `link` is a MultiscaleTreeNode.
|
||||||
top = int(link.top * h)
|
top = int(link.top * h)
|
||||||
left = int(link.left * w)
|
left = int(link.left * w)
|
||||||
|
if recurrent is not None:
|
||||||
recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
|
recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
|
||||||
if self.feed_gen_output_into_input:
|
if self.feed_gen_output_into_input:
|
||||||
top *= 2
|
top *= 2
|
||||||
|
|
|
@ -13,7 +13,7 @@ from skimage import io
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_prog_mi1_rrdb_6bypass.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
@ -50,6 +50,7 @@ def main():
|
||||||
len(train_set), train_size))
|
len(train_set), train_size))
|
||||||
assert train_loader is not None
|
assert train_loader is not None
|
||||||
|
|
||||||
|
'''
|
||||||
tq_ldr = tqdm(train_set.get_paths())
|
tq_ldr = tqdm(train_set.get_paths())
|
||||||
for path in tq_ldr:
|
for path in tq_ldr:
|
||||||
try:
|
try:
|
||||||
|
@ -58,6 +59,10 @@ def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error with %s" % (path,))
|
print("Error with %s" % (path,))
|
||||||
print(e)
|
print(e)
|
||||||
|
'''
|
||||||
|
tq_ldr = tqdm(train_set)
|
||||||
|
for ds in tq_ldr:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -265,7 +265,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -23,8 +23,6 @@ def parse(opt_path, is_train=True):
|
||||||
for phase, dataset in opt['datasets'].items():
|
for phase, dataset in opt['datasets'].items():
|
||||||
phase = phase.split('_')[0]
|
phase = phase.split('_')[0]
|
||||||
dataset['phase'] = phase
|
dataset['phase'] = phase
|
||||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
|
||||||
dataset['scale'] = scale
|
|
||||||
is_lmdb = False
|
is_lmdb = False
|
||||||
''' LMDB is not supported at this point with the mods I've been making.
|
''' LMDB is not supported at this point with the mods I've been making.
|
||||||
if dataset.get('dataroot_GT', None) is not None:
|
if dataset.get('dataroot_GT', None) is not None:
|
||||||
|
@ -67,11 +65,6 @@ def parse(opt_path, is_train=True):
|
||||||
opt['path']['results_root'] = results_root
|
opt['path']['results_root'] = results_root
|
||||||
opt['path']['log'] = results_root
|
opt['path']['log'] = results_root
|
||||||
|
|
||||||
# network
|
|
||||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
|
||||||
if 'network_G' in opt.keys():
|
|
||||||
opt['network_G']['scale'] = scale
|
|
||||||
|
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user