diff --git a/codes/models/archs/HighToLowResNet.py b/codes/models/archs/HighToLowResNet.py index 2989f4cc..0b7e6cb0 100644 --- a/codes/models/archs/HighToLowResNet.py +++ b/codes/models/archs/HighToLowResNet.py @@ -27,14 +27,18 @@ class HighToLowResNet(nn.Module): self.trunk_lores = arch_util.make_layer(basic_block2, nb - 15) # downsampling - if self.downscale == 4: + if self.downscale == 4 or self.downscale == 1: self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True) self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True) else: raise EnvironmentError("Requested downscale not supported: %i" % (downscale,)) self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True) - self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True) + if self.downscale == 4: + self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True) + else: + self.pixel_shuffle = nn.PixelShuffle(4) + self.conv_last = nn.Conv2d(int(nf/8), out_nc, 3, stride=1, padding=1, bias=True) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) @@ -51,13 +55,22 @@ class HighToLowResNet(nn.Module): out = self.lrelu(self.conv_first(out)) out = self.trunk_hires(out) - if self.downscale == 4: + if self.downscale == 4 or self.downscale == 1: out = self.lrelu(self.downconv1(out)) out = self.trunk_medres(out) out = self.lrelu(self.downconv2(out)) out = self.trunk_lores(out) - out = self.conv_last(self.lrelu(self.HRconv(out))) - base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False) + if self.downscale == 1: + out = self.lrelu(self.pixel_shuffle(self.HRconv(out))) + out = self.conv_last(out) + else: + out = self.conv_last(self.lrelu(self.HRconv(out))) + + if self.downscale == 1: + base = x + else: + base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False) + out += base return out diff --git a/codes/options/train/train_DownsampleGAN_blacked.yml b/codes/options/train/train_GAN_blacked_corrupt_via_transform.yml similarity index 60% rename from codes/options/train/train_DownsampleGAN_blacked.yml rename to codes/options/train/train_GAN_blacked_corrupt_via_transform.yml index 55c310b1..4a9f0da6 100644 --- a/codes/options/train/train_DownsampleGAN_blacked.yml +++ b/codes/options/train/train_GAN_blacked_corrupt_via_transform.yml @@ -1,9 +1,9 @@ #### general settings -name: downsample_GAN_blacked +name: corruptGAN_4k_lqprn_closeup_hq_to_hq use_tb_logger: true -model: srgan -distortion: sr -scale: .25 +model: corruptgan +distortion: downsample +scale: 1 gpu_ids: [0] amp_opt_level: O1 @@ -11,41 +11,41 @@ amp_opt_level: O1 datasets: train: name: blacked - mode: GTLQ - dataroot_GT: ../datasets/blacked/train/hr - dataroot_LQ: ../datasets/lqprn/train/lr - + mode: downsample + dataroot_GT: K:\\4k6k\\4k_closeup\\hr + dataroot_LQ: E:\\4k6k\\adrianna\\for_training\\hr + mismatched_Data_OK: true use_shuffle: true - n_workers: 0 # per GPU - batch_size: 7 - target_size: 64 + n_workers: 4 # per GPU + batch_size: 16 + target_size: 256 use_flip: false use_rot: false color: RGB val: name: blacked_val - mode: LQGT - dataroot_GT: ../datasets/vrp/validation/hr - dataroot_LQ: ../datasets/vrp/validation/lr + mode: downsample + target_size: 256 + dataroot_GT: ../datasets/blacked/val/hr + dataroot_LQ: ../datasets/blacked/val/hr #### network structures network_G: - which_model_G: RRDBNet + which_model_G: HighToLowResNet in_nc: 3 out_nc: 3 nf: 64 - nb: 23 + nb: 56 network_D: which_model_D: discriminator_vgg_128 in_nc: 3 - nf: 64 + nf: 128 #### path path: - #pretrain_model_G: ../experiments/blacked_gen_20000_epochs.pth - #pretrain_model_D: ../experiments/blacked_disc_20000_epochs.pth - strict_load: true + pretrain_model_G: ~ resume_state: ~ + strict_load: true #### training settings: learning rate scheme, loss train: @@ -61,16 +61,17 @@ train: niter: 400000 warmup_iter: -1 # no warm up - lr_steps: [10000, 30000, 50000, 70000] + lr_steps: [4000, 8000, 12000, 15000, 20000] lr_gamma: 0.5 pixel_criterion: l1 pixel_weight: !!float 1e-2 + feature_criterion: l1 feature_weight: 0 gan_type: ragan # gan | ragan - gan_weight: 1 + gan_weight: !!float 1e-1 - D_update_ratio: 1 + D_update_ratio: 2 D_init_iters: 0 manual_seed: 10