Mods needed to support training a corruptor again:
- Allow original SPSRNet to have a specifiable block increment - Cleanup - Bug fixes in code that hasn't been touched in a while.
This commit is contained in:
parent
bfdfaab911
commit
6657a406ac
|
@ -131,7 +131,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||||
|
|
||||||
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||||
|
@ -150,6 +150,10 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Iterate through the steps, performing them one at a time.
|
# Iterate through the steps, performing them one at a time.
|
||||||
state = self.dstate
|
state = self.dstate
|
||||||
for step_num, s in enumerate(self.steps):
|
for step_num, s in enumerate(self.steps):
|
||||||
|
# Skip steps if mod_step doesn't line up.
|
||||||
|
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
|
||||||
|
continue
|
||||||
|
|
||||||
# Only set requires_grad=True for the network being trained.
|
# Only set requires_grad=True for the network being trained.
|
||||||
nets_to_train = s.get_networks_trained()
|
nets_to_train = s.get_networks_trained()
|
||||||
enabled = 0
|
enabled = 0
|
||||||
|
|
|
@ -83,15 +83,16 @@ class ImageGradientNoPadding(nn.Module):
|
||||||
|
|
||||||
class SPSRNet(nn.Module):
|
class SPSRNet(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
||||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
|
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', bl_inc=5):
|
||||||
super(SPSRNet, self).__init__()
|
super(SPSRNet, self).__init__()
|
||||||
|
|
||||||
|
self.bl_inc = bl_inc
|
||||||
n_upscale = int(math.log(upscale, 2))
|
n_upscale = int(math.log(upscale, 2))
|
||||||
|
|
||||||
if upscale == 3:
|
if upscale == 3:
|
||||||
n_upscale = 1
|
n_upscale = 1
|
||||||
|
|
||||||
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
|
fea_conv = B.conv_block(in_nc + 1, nf, kernel_size=3, norm_type=None, act_type=None)
|
||||||
rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
|
rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
|
||||||
|
|
||||||
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
|
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
|
||||||
|
@ -161,31 +162,33 @@ class SPSRNet(nn.Module):
|
||||||
self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
|
self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: torch.Tensor):
|
||||||
|
|
||||||
x_grad = self.get_g_nopadding(x)
|
x_grad = self.get_g_nopadding(x)
|
||||||
|
|
||||||
|
b, f, w, h = x.shape
|
||||||
|
x = torch.cat([x, torch.randn(b, 1, w, h, device=x.get_device())], dim=1)
|
||||||
x = self.model[0](x)
|
x = self.model[0](x)
|
||||||
|
|
||||||
x, block_list = self.model[1](x)
|
x, block_list = self.model[1](x)
|
||||||
|
|
||||||
x_ori = x
|
x_ori = x
|
||||||
for i in range(5):
|
for i in range(self.bl_inc):
|
||||||
x = block_list[i](x)
|
x = block_list[i](x)
|
||||||
x_fea1 = x
|
x_fea1 = x
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(self.bl_inc):
|
||||||
x = block_list[i+5](x)
|
x = block_list[i+self.bl_inc](x)
|
||||||
x_fea2 = x
|
x_fea2 = x
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(self.bl_inc):
|
||||||
x = block_list[i+10](x)
|
x = block_list[i+self.bl_inc*2](x)
|
||||||
x_fea3 = x
|
x_fea3 = x
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(self.bl_inc):
|
||||||
x = block_list[i+15](x)
|
x = block_list[i+self.bl_inc*3](x)
|
||||||
x_fea4 = x
|
x_fea4 = x
|
||||||
|
|
||||||
x = block_list[20:](x)
|
x = block_list[self.bl_inc*4:](x)
|
||||||
#short cut
|
#short cut
|
||||||
x = x_ori+x
|
x = x_ori+x
|
||||||
x= self.model[2:](x)
|
x= self.model[2:](x)
|
||||||
|
@ -228,7 +231,7 @@ class SPSRNet(nn.Module):
|
||||||
x_out = self._branch_pretrain_HR_conv1(x_out)
|
x_out = self._branch_pretrain_HR_conv1(x_out)
|
||||||
|
|
||||||
#########
|
#########
|
||||||
return x_out_branch, x_out, x_gradn
|
return x_out_branch, x_out, x_grad
|
||||||
|
|
||||||
|
|
||||||
class SwitchedSpsr(nn.Module):
|
class SwitchedSpsr(nn.Module):
|
||||||
|
|
|
@ -254,7 +254,7 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
x = x1 + rand_feature
|
x = x1 + rand_feature
|
||||||
|
|
||||||
if self.pre_transform:
|
if self.pre_transform:
|
||||||
x = self.pre_transform(*x)
|
x = self.pre_transform(x)
|
||||||
if not isinstance(x, tuple):
|
if not isinstance(x, tuple):
|
||||||
x = (x,)
|
x = (x,)
|
||||||
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
|
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
|
||||||
|
|
|
@ -41,34 +41,6 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
gen_scale = scale * initial_stride
|
gen_scale = scale * initial_stride
|
||||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, initial_stride=initial_stride)
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, initial_stride=initial_stride)
|
||||||
elif which_model == 'AssistedRRDBNet':
|
|
||||||
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
|
||||||
elif which_model == 'LowDimRRDBNet':
|
|
||||||
gen_scale = scale * opt_net['initial_stride']
|
|
||||||
rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
|
|
||||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
|
|
||||||
elif which_model == 'PixRRDBNet':
|
|
||||||
block_f = None
|
|
||||||
if opt_net['attention']:
|
|
||||||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
|
||||||
init_temperature=opt_net['temperature'],
|
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
|
||||||
if opt_net['mhattention']:
|
|
||||||
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
|
|
||||||
init_temperature=opt_net['temperature'],
|
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
|
||||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator":
|
|
||||||
netG = srg1.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
|
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
|
||||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
|
||||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
|
||||||
trans_filters_mid=opt_net['trans_filters_mid'],
|
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
||||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
switch_reductions=opt_net['switch_reductions'],
|
||||||
|
@ -78,17 +50,6 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator3":
|
|
||||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(base_filters=opt_net['base_filters'], trans_count=opt_net['trans_count'])
|
|
||||||
elif which_model == "NestedSwitchGenerator":
|
|
||||||
netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'],
|
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
|
||||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
|
||||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
|
||||||
transformation_filters=opt_net['transformation_filters'],
|
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator4":
|
elif which_model == "ConfigurableSwitchedResidualGenerator4":
|
||||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'],
|
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'],
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
switch_reductions=opt_net['switch_reductions'],
|
||||||
|
@ -98,26 +59,15 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||||
elif which_model == "ProgressiveSRG2":
|
|
||||||
netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
|
|
||||||
growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
|
|
||||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
|
||||||
trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
|
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
|
||||||
start_step=opt_net['start_step'])
|
|
||||||
elif which_model == 'spsr_net':
|
elif which_model == 'spsr_net':
|
||||||
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||||
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
|
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
|
||||||
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
|
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv', bl_inc=opt_net['bl_inc'])
|
||||||
if opt['is_train']:
|
if opt['is_train']:
|
||||||
arch_util.initialize_weights(netG, scale=.1)
|
arch_util.initialize_weights(netG, scale=.1)
|
||||||
elif which_model == 'spsr_net_improved':
|
elif which_model == 'spsr_net_improved':
|
||||||
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
elif which_model == 'spsr_net_improved_noskip':
|
|
||||||
netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
|
||||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
|
||||||
elif which_model == "spsr_switched":
|
elif which_model == "spsr_switched":
|
||||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||||
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||||
|
|
|
@ -152,8 +152,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
l_mfake = self.criterion(d_mismatch_fake, False)
|
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||||
l_total += l_mreal + l_mfake
|
l_total += l_mreal + l_mfake
|
||||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||||
self.metrics.append(("l_fake", l_fake))
|
|
||||||
self.metrics.append(("l_real", l_real))
|
|
||||||
return l_total
|
return l_total
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||||
|
|
|
@ -61,7 +61,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
model.feed_data(data, need_GT=need_GT)
|
model.feed_data(data, need_GT=need_GT)
|
||||||
model.test()
|
model.test()
|
||||||
|
|
||||||
visuals = model.get_current_visuals()['rlt'].cpu()
|
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
|
||||||
fea_loss = 0
|
fea_loss = 0
|
||||||
for i in range(visuals.shape[0]):
|
for i in range(visuals.shape[0]):
|
||||||
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
|
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
|
||||||
|
@ -76,7 +76,8 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
else:
|
else:
|
||||||
save_img_path = osp.join(output_dir, img_name + '.png')
|
save_img_path = osp.join(output_dir, img_name + '.png')
|
||||||
|
|
||||||
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
|
if need_GT:
|
||||||
|
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
|
||||||
|
|
||||||
util.save_img(sr_img, save_img_path)
|
util.save_img(sr_img, save_img_path)
|
||||||
return fea_loss
|
return fea_loss
|
||||||
|
@ -88,7 +89,7 @@ if __name__ == "__main__":
|
||||||
want_just_images = True
|
want_just_images = True
|
||||||
srg_analyze = False
|
srg_analyze = False
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user