Mods needed to support training a corruptor again:

- Allow original SPSRNet to have a specifiable block increment
- Cleanup
- Bug fixes in code that hasnt been touched in awhile.
This commit is contained in:
James Betker 2020-09-04 15:33:39 -06:00
parent bfdfaab911
commit 6657a406ac
7 changed files with 28 additions and 72 deletions

View File

@ -131,7 +131,7 @@ class ExtensibleTrainer(BaseModel):
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
@ -150,6 +150,10 @@ class ExtensibleTrainer(BaseModel):
# Iterate through the steps, performing them one at a time.
state = self.dstate
for step_num, s in enumerate(self.steps):
# Skip steps if mod_step doesn't line up.
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
continue
# Only set requires_grad=True for the network being trained.
nets_to_train = s.get_networks_trained()
enabled = 0

View File

@ -83,15 +83,16 @@ class ImageGradientNoPadding(nn.Module):
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', bl_inc=5):
super(SPSRNet, self).__init__()
self.bl_inc = bl_inc
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
fea_conv = B.conv_block(in_nc + 1, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
@ -161,31 +162,33 @@ class SPSRNet(nn.Module):
self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
def forward(self, x):
def forward(self, x: torch.Tensor):
x_grad = self.get_g_nopadding(x)
b, f, w, h = x.shape
x = torch.cat([x, torch.randn(b, 1, w, h, device=x.get_device())], dim=1)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
for i in range(5):
for i in range(self.bl_inc):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i+5](x)
for i in range(self.bl_inc):
x = block_list[i+self.bl_inc](x)
x_fea2 = x
for i in range(5):
x = block_list[i+10](x)
for i in range(self.bl_inc):
x = block_list[i+self.bl_inc*2](x)
x_fea3 = x
for i in range(5):
x = block_list[i+15](x)
for i in range(self.bl_inc):
x = block_list[i+self.bl_inc*3](x)
x_fea4 = x
x = block_list[20:](x)
x = block_list[self.bl_inc*4:](x)
#short cut
x = x_ori+x
x= self.model[2:](x)
@ -228,7 +231,7 @@ class SPSRNet(nn.Module):
x_out = self._branch_pretrain_HR_conv1(x_out)
#########
return x_out_branch, x_out, x_gradn
return x_out_branch, x_out, x_grad
class SwitchedSpsr(nn.Module):

View File

@ -254,7 +254,7 @@ class ConfigurableSwitchComputer(nn.Module):
x = x1 + rand_feature
if self.pre_transform:
x = self.pre_transform(*x)
x = self.pre_transform(x)
if not isinstance(x, tuple):
x = (x,)
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]

View File

@ -41,34 +41,6 @@ def define_G(opt, net_key='network_G', scale=None):
gen_scale = scale * initial_stride
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, initial_stride=initial_stride)
elif which_model == 'AssistedRRDBNet':
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
elif which_model == 'LowDimRRDBNet':
gen_scale = scale * opt_net['initial_stride']
rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
elif which_model == 'PixRRDBNet':
block_f = None
if opt_net['attention']:
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
if opt_net['mhattention']:
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
elif which_model == "ConfigurableSwitchedResidualGenerator":
netG = srg1.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
trans_filters_mid=opt_net['trans_filters_mid'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ConfigurableSwitchedResidualGenerator2":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
@ -78,17 +50,6 @@ def define_G(opt, net_key='network_G', scale=None):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ConfigurableSwitchedResidualGenerator3":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(base_filters=opt_net['base_filters'], trans_count=opt_net['trans_count'])
elif which_model == "NestedSwitchGenerator":
netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ConfigurableSwitchedResidualGenerator4":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
@ -98,26 +59,15 @@ def define_G(opt, net_key='network_G', scale=None):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ProgressiveSRG2":
netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
start_step=opt_net['start_step'])
elif which_model == 'spsr_net':
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv', bl_inc=opt_net['bl_inc'])
if opt['is_train']:
arch_util.initialize_weights(netG, scale=.1)
elif which_model == 'spsr_net_improved':
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'spsr_net_improved_noskip':
netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == "spsr_switched":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],

View File

@ -152,8 +152,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
l_mfake = self.criterion(d_mismatch_fake, False)
l_total += l_mreal + l_mfake
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
self.metrics.append(("l_fake", l_fake))
self.metrics.append(("l_real", l_real))
return l_total
elif self.opt['gan_type'] == 'ragan':
return (self.criterion(d_real - torch.mean(d_fake), True) +

View File

@ -61,7 +61,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
model.feed_data(data, need_GT=need_GT)
model.test()
visuals = model.get_current_visuals()['rlt'].cpu()
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
fea_loss = 0
for i in range(visuals.shape[0]):
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
@ -76,7 +76,8 @@ def forward_pass(model, output_dir, alteration_suffix=''):
else:
save_img_path = osp.join(output_dir, img_name + '.png')
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
if need_GT:
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
util.save_img(sr_img, save_img_path)
return fea_loss
@ -88,7 +89,7 @@ if __name__ == "__main__":
want_just_images = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)