diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py index 70f80929..6991b0fc 100644 --- a/codes/models/archs/DiscriminatorResnet_arch.py +++ b/codes/models/archs/DiscriminatorResnet_arch.py @@ -94,7 +94,7 @@ class FixupBottleneck(nn.Module): class FixupResNet(nn.Module): - def __init__(self, block, layers, num_filters=64, num_classes=1000): + def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64): super(FixupResNet, self).__init__() self.num_layers = sum(layers) self.inplanes = num_filters @@ -107,7 +107,8 @@ class FixupResNet(nn.Module): self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2) self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) self.bias2 = nn.Parameter(torch.zeros(1)) - self.fc1 = nn.Linear(num_filters * 8 * 2 * 2, 100) + reduced_img_sz = int(input_img_size / 32) + self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) self.fc2 = nn.Linear(100, num_classes) for m in self.modules(): diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 4025726f..2ec07ff1 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -50,13 +50,13 @@ class RRDBNet(nn.Module): super(RRDBNet, self).__init__() RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True) self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) diff --git a/codes/models/networks.py b/codes/models/networks.py index 995670bd..172b6387 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -58,7 +58,7 @@ def define_D(opt): if which_model == 'discriminator_vgg_128': netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) elif which_model == 'discriminator_resnet': - netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1) + netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/options/train/finetune_ESRGAN_blacked.yml b/codes/options/train/finetune_ESRGAN_blacked.yml deleted file mode 100644 index af8f377b..00000000 --- a/codes/options/train/finetune_ESRGAN_blacked.yml +++ /dev/null @@ -1,87 +0,0 @@ -#### general settings -name: ESRGANx4_blacked_ramped_feat -use_tb_logger: true -model: srgan -distortion: sr -scale: 4 -gpu_ids: [0] -amp_opt_level: O1 - -#### datasets -datasets: - train: - name: blacked - mode: LQGT - dataroot_GT: ../datasets/blacked/train/hr - dataroot_LQ: ../datasets/lqprn/train/lr - dataroot_PIX: ../datasets/lqprn/train/hr - - use_shuffle: true - n_workers: 4 # per GPU - batch_size: 12 - target_size: 256 - use_flip: false - use_rot: false - color: RGB - val: - name: blacked_val - mode: LQGT - dataroot_GT: ../datasets/vrp/validation/hr - dataroot_LQ: ../datasets/vrp/validation/lr - -#### network structures -network_G: - which_model_G: RRDBNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 23 -network_D: - which_model_D: discriminator_vgg_128 - in_nc: 3 - nf: 64 - -#### path -path: - pretrain_model_G: ../experiments/blacked_ft_G.pth - pretrain_model_D: ../experiments/blacked_ft_D.pth - resume_state: ~ - strict_load: true - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-4 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-4 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [10000, 20000, 30000, 40000, 50000] - lr_gamma: 0.5 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 - feature_criterion: l1 - feature_weight: !!float 1e0 - feature_weight_decay: .95 - feature_weight_decay_steps: 2500 - feature_weight_minimum: !!float 1e-3 - gan_type: ragan # gan | ragan - gan_weight: !!float 1e-1 - - D_update_ratio: 1 - D_init_iters: 0 - - manual_seed: 10 - val_freq: !!float 5e2 - -#### logger -logger: - print_freq: 50 - save_checkpoint_freq: !!float 5e2 diff --git a/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml b/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml deleted file mode 100644 index cc75f847..00000000 --- a/codes/options/train/finetune_ESRGAN_blacked_for_adrianna.yml +++ /dev/null @@ -1,85 +0,0 @@ -#### general settings -name: ESRGANx4_blacked_for_adrianna -use_tb_logger: true -model: srgan -distortion: sr -scale: 4 -gpu_ids: [0] -amp_opt_level: O1 - -#### datasets -datasets: - train: - name: vixen - mode: LQGT - dataroot_GT: K:\4k6k\vixen4k\hr - dataroot_LQ: E:\4k6k\mmsr\results\RRDB_ESRGAN_x4\vixen - use_shuffle: true - n_workers: 4 # per GPU - batch_size: 12 - target_size: 256 - use_flip: false - use_rot: false - color: RGB - val: - name: adrianna_val - mode: LQGT - dataroot_GT: ../datasets/adrianna/val/hr - dataroot_LQ: ../datasets/adrianna/val/lr - -#### network structures -network_G: - which_model_G: RRDBNet - in_nc: 3 - out_nc: 3 - nf: 64 - nb: 23 -network_D: - which_model_D: discriminator_vgg_128 - in_nc: 3 - nf: 64 - -#### path -path: - pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_G.pth - pretrain_model_D: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_D.pth - resume_state: ../experiments/ESRGANx4_blacked_for_adrianna/training_state/15500.state - strict_load: true - -#### training settings: learning rate scheme, loss -train: - lr_G: !!float 1e-4 - weight_decay_G: 0 - beta1_G: 0.9 - beta2_G: 0.99 - lr_D: !!float 1e-4 - weight_decay_D: 0 - beta1_D: 0.9 - beta2_D: 0.99 - lr_scheme: MultiStepLR - - niter: 400000 - warmup_iter: -1 # no warm up - lr_steps: [10000, 20000, 30000, 40000, 50000] - lr_gamma: 0.5 - - pixel_criterion: l1 - pixel_weight: !!float 1e-2 - feature_criterion: l1 - feature_weight: !!float 1e0 - feature_weight_decay: 1 - feature_weight_decay_steps: 2500 - feature_weight_minimum: !!float 1e-3 - gan_type: ragan # gan | ragan - gan_weight: !!float 1e-1 - - D_update_ratio: 1 - D_init_iters: 0 - - manual_seed: 10 - val_freq: !!float 5e2 - -#### logger -logger: - print_freq: 50 - save_checkpoint_freq: !!float 5e2 diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml index cd6e09fb..39865580 100644 --- a/codes/options/train/train_ESRGAN.yml +++ b/codes/options/train/train_ESRGAN.yml @@ -5,18 +5,19 @@ model: srgan distortion: sr scale: 4 gpu_ids: [0] +amp_opt_level: O1 #### datasets datasets: train: name: DIV2K mode: LQGT - dataroot_GT: ../datasets/div2k/DIV2K800_sub - dataroot_LQ: ../datasets/div2k/DIV2K800_sub_bicLRx4 + dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub + dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4 use_shuffle: true n_workers: 16 # per GPU - batch_size: 16 + batch_size: 32 target_size: 128 use_flip: true use_rot: true @@ -24,8 +25,8 @@ datasets: val: name: div2kval mode: LQGT - dataroot_GT: ../datasets/div2k/div2k_valid_hr - dataroot_LQ: ../datasets/div2k/div2k_valid_lr_bicubic + dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr + dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic #### network structures network_G: @@ -35,13 +36,13 @@ network_G: nf: 64 nb: 23 network_D: - which_model_D: discriminator_vgg_128 + which_model_D: discriminator_resnet in_nc: 3 nf: 64 #### path path: - pretrain_model_G: ../experiments/RRDB_PSNR_x4.pth + pretrain_model_G: ~ strict_load: true resume_state: ~ @@ -66,10 +67,13 @@ train: pixel_weight: !!float 1e-2 feature_criterion: l1 feature_weight: 1 - gan_type: ragan # gan | ragan + feature_weight_decay: .98 + feature_weight_decay_steps: 500 + feature_weight_minimum: .1 + gan_type: gan # gan | ragan gan_weight: !!float 5e-3 - D_update_ratio: 1 + D_update_ratio: 2 D_init_iters: 0 manual_seed: 10 diff --git a/codes/options/train/finetune_ESRGAN_vrp.yml b/codes/options/train/train_ESRGAN_blacked.yml similarity index 58% rename from codes/options/train/finetune_ESRGAN_vrp.yml rename to codes/options/train/train_ESRGAN_blacked.yml index 97aaeb08..64e1af42 100644 --- a/codes/options/train/finetune_ESRGAN_vrp.yml +++ b/codes/options/train/train_ESRGAN_blacked.yml @@ -1,58 +1,56 @@ #### general settings -name: ESRGANx4_VRP +name: blacked_fix_and_upconv use_tb_logger: true model: srgan distortion: sr scale: 4 gpu_ids: [0] +amp_opt_level: O1 #### datasets datasets: train: - name: VRP + name: vixcloseup mode: LQGT - dataroot_GT: ../datasets/vrp/train/hr - dataroot_LQ: ../datasets/vrp/train/lr + dataroot_GT: K:\4k6k\4k_closeup\hr + dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted use_shuffle: true - n_workers: 0 # per GPU - batch_size: 16 - target_size: 128 - use_flip: true - use_rot: true + n_workers: 12 # per GPU + batch_size: 12 + target_size: 256 color: RGB val: - name: VRP_val + name: adrianna_val mode: LQGT - dataroot_GT: ../datasets/vrp/validation/hr - dataroot_LQ: ../datasets/vrp/validation/lr + dataroot_GT: E:\4k6k\datasets\adrianna\val\hhq + dataroot_LQ: E:\4k6k\datasets\adrianna\val\hr #### network structures network_G: which_model_G: RRDBNet in_nc: 3 out_nc: 3 - nf: 64 + nf: 48 nb: 23 network_D: - which_model_D: discriminator_vgg_128 + which_model_D: discriminator_resnet in_nc: 3 - nf: 64 + nf: 48 #### path path: - pretrain_model_G: ../experiments/div2k_gen_pretrain.pth - pretrain_model_D: ../experiments/div2k_disc_pretrain.pth + pretrain_model_G: ~ strict_load: true resume_state: ~ #### training settings: learning rate scheme, loss train: - lr_G: !!float 1e-5 + lr_G: !!float 1e-4 weight_decay_G: 0 beta1_G: 0.9 beta2_G: 0.99 - lr_D: !!float 1e-5 + lr_D: !!float 2e-4 weight_decay_D: 0 beta1_D: 0.9 beta2_D: 0.99 @@ -60,14 +58,17 @@ train: niter: 400000 warmup_iter: -1 # no warm up - lr_steps: [50000, 100000, 200000, 300000] + lr_steps: [20000, 40000, 60000, 80000] lr_gamma: 0.5 pixel_criterion: l1 pixel_weight: !!float 1e-2 feature_criterion: l1 feature_weight: 1 - gan_type: ragan # gan | ragan + feature_weight_decay: .98 + feature_weight_decay_steps: 500 + feature_weight_minimum: .1 + gan_type: gan # gan | ragan gan_weight: !!float 5e-3 D_update_ratio: 1 diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml index 220d3375..17ade411 100644 --- a/codes/options/train/train_GAN_blacked_corrupt.yml +++ b/codes/options/train/train_GAN_blacked_corrupt.yml @@ -21,6 +21,7 @@ datasets: target_size: 64 use_flip: false use_rot: false + doCrop: false color: RGB val: name: blacked_val @@ -48,18 +49,18 @@ network_D: #### path path: - pretrain_model_G: ~ - pretrain_model_D: ~ #../experiments/resnet_corrupt_discriminator_fixup.pth - resume_state: ~ + pretrain_model_G: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/models/29000_G.pth + pretrain_model_D: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/models/29000_D.pth + resume_state: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/training_state/29000.state strict_load: true #### training settings: learning rate scheme, loss train: - lr_G: !!float 1e-4 + lr_G: !!float 5e-5 weight_decay_G: 0 beta1_G: 0.9 beta2_G: 0.99 - lr_D: !!float 2e-4 + lr_D: !!float 1e-4 weight_decay_D: 0 beta1_D: 0.9 beta2_D: 0.99 @@ -77,8 +78,8 @@ train: gan_type: gan # gan | ragan gan_weight: !!float 1e-1 - D_update_ratio: 2 - D_init_iters: 0 + D_update_ratio: 1 + D_init_iters: -1 manual_seed: 10 val_freq: !!float 5e2 diff --git a/codes/train.py b/codes/train.py index fb9b4cfa..645ac867 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_GAN_blacked_corrupt.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -191,6 +191,8 @@ def main(): idx = 0 for val_data in val_loader: idx += 1 + if idx >= 20: + break img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir)