diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 30c4f7aa..ed9129c1 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel import models.networks as networks import models.lr_scheduler as lr_scheduler -from .base_model import BaseModel +from models.base_model import BaseModel from models.loss import GANLoss from apex import amp @@ -150,8 +150,11 @@ class SRGANModel(BaseModel): for p in self.netD.parameters(): p.requires_grad = False - self.optimizer_G.zero_grad() - self.fake_H = self.netG(self.var_L) + if step > self.D_init_iters: + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + else: + self.fake_H = self.pix if step % 50 == 0: for i in range(self.var_L.shape[0]): @@ -235,9 +238,9 @@ class SRGANModel(BaseModel): self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_g_total'] = l_g_total.item() - self.log_dict['l_d_real'] = l_d_real.item() - self.log_dict['l_d_fake'] = l_d_fake.item() - self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() diff --git a/codes/models/archs/HighToLowResNet.py b/codes/models/archs/HighToLowResNet.py index 470359f2..2989f4cc 100644 --- a/codes/models/archs/HighToLowResNet.py +++ b/codes/models/archs/HighToLowResNet.py @@ -22,9 +22,9 @@ class HighToLowResNet(nn.Module): # The first will be applied against the hi-res inputs and will have only 4 layers. # The second will be applied after half of the downscaling and will also have only 6 layers. # The final will be applied against the final resolution and will have all of the remaining layers. - self.trunk_hires = arch_util.make_layer(basic_block, 4) - self.trunk_medres = arch_util.make_layer(basic_block, 6) - self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10) + self.trunk_hires = arch_util.make_layer(basic_block, 5) + self.trunk_medres = arch_util.make_layer(basic_block, 10) + self.trunk_lores = arch_util.make_layer(basic_block2, nb - 15) # downsampling if self.downscale == 4: diff --git a/codes/options/test/test_ESRGAN_vrp.yml b/codes/options/test/test_ESRGAN_vrp.yml index 403911b1..5c22ad0d 100644 --- a/codes/options/test/test_ESRGAN_vrp.yml +++ b/codes/options/test/test_ESRGAN_vrp.yml @@ -23,4 +23,4 @@ network_G: #### path path: - pretrain_model_G: ../experiments/ESRGANx4_blacked_ramped_feat/models/35000_G.pth + pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/19500_G.pth diff --git a/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml b/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml index 60ef4e17..cc75f847 100644 --- a/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml +++ b/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml @@ -41,9 +41,9 @@ network_D: #### path path: - pretrain_model_G: ../experiments/blacked_ft_G.pth - pretrain_model_D: ../experiments/blacked_ft_D.pth - resume_state: ~ + pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_G.pth + pretrain_model_D: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_D.pth + resume_state: ../experiments/ESRGANx4_blacked_for_adrianna/training_state/15500.state strict_load: true #### training settings: learning rate scheme, loss @@ -71,7 +71,7 @@ train: feature_weight_decay_steps: 2500 feature_weight_minimum: !!float 1e-3 gan_type: ragan # gan | ragan - gan_weight: !!float 1e-2 + gan_weight: !!float 1e-1 D_update_ratio: 1 D_init_iters: 0 diff --git a/codes/options/train/finetune_corruptGAN_adrianna.yml b/codes/options/train/finetune_corruptGAN_adrianna.yml deleted file mode 100644 index df067da5..00000000 --- a/codes/options/train/finetune_corruptGAN_adrianna.yml +++ /dev/null @@ -1,84 +0,0 @@ -#### general settings -name: ESRGAN_adrianna_corrupt_finetune -use_tb_logger: true -model: corruptgan -distortion: downsample -scale: 4 -gpu_ids: [0] -amp_opt_level: O1 - -#### datasets -datasets: - train: - name: blacked - mode: downsample - dataroot_GT: ../datasets/blacked/train/hr - dataroot_LQ: ../datasets/adrianna/train/lr - mismatched_Data_OK: true - use_shuffle: true - n_workers: 4 # per GPU - batch_size: 16 - target_size: 64 - use_flip: false - use_rot: false - color: RGB - val: - name: blacked_val - mode: downsample - target_size: 64 - dataroot_GT: ../datasets/blacked/val/hr - dataroot_LQ: ../datasets/blacked/val/lr - -#### network structures -network_G: - which_model_G: HighToLowResNet - in_nc: 3 - out_nc: 3 - nf: 128 - nb: 30 -network_D: - which_model_D: discriminator_vgg_128 - in_nc: 3 - nf: 96 - -#### path -path: - pretrain_model_G: ../experiments/blacked_lqprn_corrupt_G.pth - pretrain_model_D: ../experiments/blacked_lqprn_corrupt_D.pth - resume_state: ~ - strict_load: true - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-5 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-5 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [1000, 2000, 3000] - lr_gamma: 0.5 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 - feature_criterion: l1 - feature_weight: 0 - gan_type: gan # gan | ragan - gan_weight: !!float 1e-1 - - D_update_ratio: 1 - D_init_iters: 0 - - manual_seed: 10 - val_freq: !!float 5e2 - -#### logger -logger: - print_freq: 50 - save_checkpoint_freq: !!float 5e2 diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml index 30e8ed23..a9ffe77b 100644 --- a/codes/options/train/train_GAN_blacked_corrupt.yml +++ b/codes/options/train/train_GAN_blacked_corrupt.yml @@ -1,5 +1,5 @@ #### general settings -name: ESRGAN_blacked_corrupt_lqprn +name: corruptGAN_4k_lqprn_closeup use_tb_logger: true model: corruptgan distortion: downsample @@ -12,9 +12,9 @@ datasets: train: name: blacked mode: downsample - dataroot_GT: ../datasets/blacked/train/hr - dataroot_LQ: ../datasets/lqprn/train/lr - mismatched_Data_OK: false + dataroot_GT: K:\\4k6k\\4k_closeup\\hr + dataroot_LQ: E:\\4k6k\\adrianna\\for_training\\lr + mismatched_Data_OK: true use_shuffle: true n_workers: 4 # per GPU batch_size: 16 @@ -34,12 +34,12 @@ network_G: which_model_G: HighToLowResNet in_nc: 3 out_nc: 3 - nf: 128 - nb: 30 + nf: 64 + nb: 64 network_D: which_model_D: discriminator_vgg_128 in_nc: 3 - nf: 96 + nf: 64 #### path path: @@ -61,7 +61,7 @@ train: niter: 400000 warmup_iter: -1 # no warm up - lr_steps: [1000, 2000, 3500, 5000, 6500] + lr_steps: [4000, 8000, 12000, 15000, 20000] lr_gamma: 0.5 pixel_criterion: l1 @@ -71,8 +71,8 @@ train: gan_type: gan # gan | ragan gan_weight: !!float 1e-1 - D_update_ratio: 1 - D_init_iters: 0 + D_update_ratio: 2 + D_init_iters: 500 manual_seed: 10 val_freq: !!float 5e2 diff --git a/codes/test.py b/codes/test.py index 1586bef0..f1e74f6a 100644 --- a/codes/test.py +++ b/codes/test.py @@ -15,7 +15,7 @@ if __name__ == "__main__": #### options want_just_images = True parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_corrupt_vixen_adrianna.yml') + parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) diff --git a/codes/train.py b/codes/train.py index 2906463a..befd5d97 100644 --- a/codes/train.py +++ b/codes/train.py @@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_GAN_blacked_corrupt.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)