forked from mrq/DL-Art-School
Mods needed to support training a corruptor again:
- Allow original SPSRNet to have a specifiable block increment - Cleanup - Bug fixes in code that hasnt been touched in awhile.
This commit is contained in:
parent
bfdfaab911
commit
6657a406ac
|
@ -131,7 +131,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||
|
@ -150,6 +150,10 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Iterate through the steps, performing them one at a time.
|
||||
state = self.dstate
|
||||
for step_num, s in enumerate(self.steps):
|
||||
# Skip steps if mod_step doesn't line up.
|
||||
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
|
||||
continue
|
||||
|
||||
# Only set requires_grad=True for the network being trained.
|
||||
nets_to_train = s.get_networks_trained()
|
||||
enabled = 0
|
||||
|
|
|
@ -83,15 +83,16 @@ class ImageGradientNoPadding(nn.Module):
|
|||
|
||||
class SPSRNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', bl_inc=5):
|
||||
super(SPSRNet, self).__init__()
|
||||
|
||||
self.bl_inc = bl_inc
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
|
||||
fea_conv = B.conv_block(in_nc + 1, nf, kernel_size=3, norm_type=None, act_type=None)
|
||||
rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
|
||||
|
||||
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
|
||||
|
@ -161,31 +162,33 @@ class SPSRNet(nn.Module):
|
|||
self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
|
||||
b, f, w, h = x.shape
|
||||
x = torch.cat([x, torch.randn(b, 1, w, h, device=x.get_device())], dim=1)
|
||||
x = self.model[0](x)
|
||||
|
||||
x, block_list = self.model[1](x)
|
||||
|
||||
x_ori = x
|
||||
for i in range(5):
|
||||
for i in range(self.bl_inc):
|
||||
x = block_list[i](x)
|
||||
x_fea1 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i+5](x)
|
||||
for i in range(self.bl_inc):
|
||||
x = block_list[i+self.bl_inc](x)
|
||||
x_fea2 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i+10](x)
|
||||
for i in range(self.bl_inc):
|
||||
x = block_list[i+self.bl_inc*2](x)
|
||||
x_fea3 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i+15](x)
|
||||
for i in range(self.bl_inc):
|
||||
x = block_list[i+self.bl_inc*3](x)
|
||||
x_fea4 = x
|
||||
|
||||
x = block_list[20:](x)
|
||||
x = block_list[self.bl_inc*4:](x)
|
||||
#short cut
|
||||
x = x_ori+x
|
||||
x= self.model[2:](x)
|
||||
|
@ -228,7 +231,7 @@ class SPSRNet(nn.Module):
|
|||
x_out = self._branch_pretrain_HR_conv1(x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_gradn
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
class SwitchedSpsr(nn.Module):
|
||||
|
|
|
@ -254,7 +254,7 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
x = x1 + rand_feature
|
||||
|
||||
if self.pre_transform:
|
||||
x = self.pre_transform(*x)
|
||||
x = self.pre_transform(x)
|
||||
if not isinstance(x, tuple):
|
||||
x = (x,)
|
||||
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
|
||||
|
|
|
@ -41,34 +41,6 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
gen_scale = scale * initial_stride
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, initial_stride=initial_stride)
|
||||
elif which_model == 'AssistedRRDBNet':
|
||||
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
||||
elif which_model == 'LowDimRRDBNet':
|
||||
gen_scale = scale * opt_net['initial_stride']
|
||||
rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
|
||||
elif which_model == 'PixRRDBNet':
|
||||
block_f = None
|
||||
if opt_net['attention']:
|
||||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
if opt_net['mhattention']:
|
||||
block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator":
|
||||
netG = srg1.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
trans_filters_mid=opt_net['trans_filters_mid'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
|
@ -78,17 +50,6 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator3":
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(base_filters=opt_net['base_filters'], trans_count=opt_net['trans_count'])
|
||||
elif which_model == "NestedSwitchGenerator":
|
||||
netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator4":
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
|
@ -98,26 +59,15 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "ProgressiveSRG2":
|
||||
netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
|
||||
growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
start_step=opt_net['start_step'])
|
||||
elif which_model == 'spsr_net':
|
||||
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
|
||||
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
|
||||
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv', bl_inc=opt_net['bl_inc'])
|
||||
if opt['is_train']:
|
||||
arch_util.initialize_weights(netG, scale=.1)
|
||||
elif which_model == 'spsr_net_improved':
|
||||
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == 'spsr_net_improved_noskip':
|
||||
netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == "spsr_switched":
|
||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
|
|
|
@ -152,8 +152,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||
l_total += l_mreal + l_mfake
|
||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||
self.metrics.append(("l_fake", l_fake))
|
||||
self.metrics.append(("l_real", l_real))
|
||||
return l_total
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||
|
|
|
@ -61,7 +61,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
|||
model.feed_data(data, need_GT=need_GT)
|
||||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals()['rlt'].cpu()
|
||||
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
|
||||
fea_loss = 0
|
||||
for i in range(visuals.shape[0]):
|
||||
img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
|
||||
|
@ -76,7 +76,8 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
|||
else:
|
||||
save_img_path = osp.join(output_dir, img_name + '.png')
|
||||
|
||||
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
|
||||
if need_GT:
|
||||
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
|
||||
|
||||
util.save_img(sr_img, save_img_path)
|
||||
return fea_loss
|
||||
|
@ -88,7 +89,7 @@ if __name__ == "__main__":
|
|||
want_just_images = True
|
||||
srg_analyze = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user