diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 29ff0859..070e8913 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -131,7 +131,7 @@ class ExtensibleTrainer(BaseModel): self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0) self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data else data['GT'] + input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} @@ -150,6 +150,10 @@ class ExtensibleTrainer(BaseModel): # Iterate through the steps, performing them one at a time. state = self.dstate for step_num, s in enumerate(self.steps): + # Skip steps if mod_step doesn't line up. + if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0: + continue + # Only set requires_grad=True for the network being trained. nets_to_train = s.get_networks_trained() enabled = 0 diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 7a4acbc1..414b43ff 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -83,15 +83,16 @@ class ImageGradientNoPadding(nn.Module): class SPSRNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ - act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', bl_inc=5): super(SPSRNet, self).__init__() + self.bl_inc = bl_inc n_upscale = int(math.log(upscale, 2)) if upscale == 3: n_upscale = 1 - fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + fea_conv = B.conv_block(in_nc + 1, nf, kernel_size=3, norm_type=None, act_type=None) rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)] LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) @@ -161,31 +162,33 @@ class SPSRNet(nn.Module): self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) - def forward(self, x): - + def forward(self, x: torch.Tensor): x_grad = self.get_g_nopadding(x) + + b, f, w, h = x.shape + x = torch.cat([x, torch.randn(b, 1, w, h, device=x.get_device())], dim=1) x = self.model[0](x) x, block_list = self.model[1](x) x_ori = x - for i in range(5): + for i in range(self.bl_inc): x = block_list[i](x) x_fea1 = x - for i in range(5): - x = block_list[i+5](x) + for i in range(self.bl_inc): + x = block_list[i+self.bl_inc](x) x_fea2 = x - for i in range(5): - x = block_list[i+10](x) + for i in range(self.bl_inc): + x = block_list[i+self.bl_inc*2](x) x_fea3 = x - for i in range(5): - x = block_list[i+15](x) + for i in range(self.bl_inc): + x = block_list[i+self.bl_inc*3](x) x_fea4 = x - x = block_list[20:](x) + x = block_list[self.bl_inc*4:](x) #short cut x = x_ori+x x= self.model[2:](x) @@ -228,7 +231,7 @@ class SPSRNet(nn.Module): x_out = self._branch_pretrain_HR_conv1(x_out) ######### - return x_out_branch, x_out, x_gradn + return x_out_branch, x_out, x_grad class SwitchedSpsr(nn.Module): diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index aa0d740a..f6878aab 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -254,7 +254,7 @@ class ConfigurableSwitchComputer(nn.Module): x = x1 + rand_feature if self.pre_transform: - x = self.pre_transform(*x) + x = self.pre_transform(x) if not isinstance(x, tuple): x = (x,) xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms] diff --git a/codes/models/networks.py b/codes/models/networks.py index fd16e950..d9740ab4 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -41,34 +41,6 @@ def define_G(opt, net_key='network_G', scale=None): gen_scale = scale * initial_stride netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, initial_stride=initial_stride) - elif which_model == 'AssistedRRDBNet': - netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale) - elif which_model == 'LowDimRRDBNet': - gen_scale = scale * opt_net['initial_stride'] - rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim']) - netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride']) - elif which_model == 'PixRRDBNet': - block_f = None - if opt_net['attention']: - block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'], - init_temperature=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step']) - if opt_net['mhattention']: - block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2, nf=opt_net['nf'], gc=opt_net['gc'], - init_temperature=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step']) - netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f) - elif which_model == "ConfigurableSwitchedResidualGenerator": - netG = srg1.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], - trans_filters_mid=opt_net['trans_filters_mid'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) elif which_model == "ConfigurableSwitchedResidualGenerator2": netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], @@ -78,17 +50,6 @@ def define_G(opt, net_key='network_G', scale=None): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) - elif which_model == "ConfigurableSwitchedResidualGenerator3": - netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(base_filters=opt_net['base_filters'], trans_count=opt_net['trans_count']) - elif which_model == "NestedSwitchGenerator": - netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], - transformation_filters=opt_net['transformation_filters'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) elif which_model == "ConfigurableSwitchedResidualGenerator4": netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], @@ -98,26 +59,15 @@ def define_G(opt, net_key='network_G', scale=None): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) - elif which_model == "ProgressiveSRG2": - netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'], - growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'], - switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], - trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'], - start_step=opt_net['start_step']) elif which_model == 'spsr_net': netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], - act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') + act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv', bl_inc=opt_net['bl_inc']) if opt['is_train']: arch_util.initialize_weights(netG, scale=.1) elif which_model == 'spsr_net_improved': netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) - elif which_model == 'spsr_net_improved_noskip': - netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], - nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == "spsr_switched": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 77c53581..128f4e41 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -152,8 +152,6 @@ class DiscriminatorGanLoss(ConfigurableLoss): l_mfake = self.criterion(d_mismatch_fake, False) l_total += l_mreal + l_mfake self.metrics.append(("l_mismatch", l_mfake + l_mreal)) - self.metrics.append(("l_fake", l_fake)) - self.metrics.append(("l_real", l_real)) return l_total elif self.opt['gan_type'] == 'ragan': return (self.criterion(d_real - torch.mean(d_fake), True) + diff --git a/codes/test.py b/codes/test.py index 7d418ceb..4c96f14c 100644 --- a/codes/test.py +++ b/codes/test.py @@ -61,7 +61,7 @@ def forward_pass(model, output_dir, alteration_suffix=''): model.feed_data(data, need_GT=need_GT) model.test() - visuals = model.get_current_visuals()['rlt'].cpu() + visuals = model.get_current_visuals(need_GT)['rlt'].cpu() fea_loss = 0 for i in range(visuals.shape[0]): img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i] @@ -76,7 +76,8 @@ def forward_pass(model, output_dir, alteration_suffix=''): else: save_img_path = osp.join(output_dir, img_name + '.png') - fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i]) + if need_GT: + fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i]) util.save_img(sr_img, save_img_path) return fea_loss @@ -88,7 +89,7 @@ if __name__ == "__main__": want_just_images = True srg_analyze = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) diff --git a/codes/train.py b/codes/train.py index f4014552..6f63b7d6 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_gan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)