Fixup upconv for the next attempt!

This commit is contained in:
James Betker 2020-05-01 19:56:14 -06:00
parent 7eaabce48d
commit 9e1acfe396
9 changed files with 54 additions and 217 deletions

View File

@ -94,7 +94,7 @@ class FixupBottleneck(nn.Module):
class FixupResNet(nn.Module): class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000): def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
super(FixupResNet, self).__init__() super(FixupResNet, self).__init__()
self.num_layers = sum(layers) self.num_layers = sum(layers)
self.inplanes = num_filters self.inplanes = num_filters
@ -107,7 +107,8 @@ class FixupResNet(nn.Module):
self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2) self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1)) self.bias2 = nn.Parameter(torch.zeros(1))
self.fc1 = nn.Linear(num_filters * 8 * 2 * 2, 100) reduced_img_sz = int(input_img_size / 32)
self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = nn.Linear(100, num_classes) self.fc2 = nn.Linear(100, num_classes)
for m in self.modules(): for m in self.modules():

View File

@ -50,13 +50,13 @@ class RRDBNet(nn.Module):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb) self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling #### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

View File

@ -58,7 +58,7 @@ def define_D(opt):
if which_model == 'discriminator_vgg_128': if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == 'discriminator_resnet': elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1) netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
else: else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD return netD

View File

@ -1,87 +0,0 @@
#### general settings
name: ESRGANx4_blacked_ramped_feat
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
#### datasets
datasets:
train:
name: blacked
mode: LQGT
dataroot_GT: ../datasets/blacked/train/hr
dataroot_LQ: ../datasets/lqprn/train/lr
dataroot_PIX: ../datasets/lqprn/train/hr
use_shuffle: true
n_workers: 4 # per GPU
batch_size: 12
target_size: 256
use_flip: false
use_rot: false
color: RGB
val:
name: blacked_val
mode: LQGT
dataroot_GT: ../datasets/vrp/validation/hr
dataroot_LQ: ../datasets/vrp/validation/lr
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
network_D:
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#### path
path:
pretrain_model_G: ../experiments/blacked_ft_G.pth
pretrain_model_D: ../experiments/blacked_ft_D.pth
resume_state: ~
strict_load: true
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [10000, 20000, 30000, 40000, 50000]
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: !!float 1e0
feature_weight_decay: .95
feature_weight_decay_steps: 2500
feature_weight_minimum: !!float 1e-3
gan_type: ragan # gan | ragan
gan_weight: !!float 1e-1
D_update_ratio: 1
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e2
#### logger
logger:
print_freq: 50
save_checkpoint_freq: !!float 5e2

View File

@ -1,85 +0,0 @@
#### general settings
name: ESRGANx4_blacked_for_adrianna
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
#### datasets
datasets:
train:
name: vixen
mode: LQGT
dataroot_GT: K:\4k6k\vixen4k\hr
dataroot_LQ: E:\4k6k\mmsr\results\RRDB_ESRGAN_x4\vixen
use_shuffle: true
n_workers: 4 # per GPU
batch_size: 12
target_size: 256
use_flip: false
use_rot: false
color: RGB
val:
name: adrianna_val
mode: LQGT
dataroot_GT: ../datasets/adrianna/val/hr
dataroot_LQ: ../datasets/adrianna/val/lr
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
network_D:
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#### path
path:
pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_G.pth
pretrain_model_D: ../experiments/ESRGANx4_blacked_for_adrianna/models/15500_D.pth
resume_state: ../experiments/ESRGANx4_blacked_for_adrianna/training_state/15500.state
strict_load: true
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [10000, 20000, 30000, 40000, 50000]
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: !!float 1e0
feature_weight_decay: 1
feature_weight_decay_steps: 2500
feature_weight_minimum: !!float 1e-3
gan_type: ragan # gan | ragan
gan_weight: !!float 1e-1
D_update_ratio: 1
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e2
#### logger
logger:
print_freq: 50
save_checkpoint_freq: !!float 5e2

View File

@ -5,18 +5,19 @@ model: srgan
distortion: sr distortion: sr
scale: 4 scale: 4
gpu_ids: [0] gpu_ids: [0]
amp_opt_level: O1
#### datasets #### datasets
datasets: datasets:
train: train:
name: DIV2K name: DIV2K
mode: LQGT mode: LQGT
dataroot_GT: ../datasets/div2k/DIV2K800_sub dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
dataroot_LQ: ../datasets/div2k/DIV2K800_sub_bicLRx4 dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
use_shuffle: true use_shuffle: true
n_workers: 16 # per GPU n_workers: 16 # per GPU
batch_size: 16 batch_size: 32
target_size: 128 target_size: 128
use_flip: true use_flip: true
use_rot: true use_rot: true
@ -24,8 +25,8 @@ datasets:
val: val:
name: div2kval name: div2kval
mode: LQGT mode: LQGT
dataroot_GT: ../datasets/div2k/div2k_valid_hr dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
dataroot_LQ: ../datasets/div2k/div2k_valid_lr_bicubic dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
#### network structures #### network structures
network_G: network_G:
@ -35,13 +36,13 @@ network_G:
nf: 64 nf: 64
nb: 23 nb: 23
network_D: network_D:
which_model_D: discriminator_vgg_128 which_model_D: discriminator_resnet
in_nc: 3 in_nc: 3
nf: 64 nf: 64
#### path #### path
path: path:
pretrain_model_G: ../experiments/RRDB_PSNR_x4.pth pretrain_model_G: ~
strict_load: true strict_load: true
resume_state: ~ resume_state: ~
@ -66,10 +67,13 @@ train:
pixel_weight: !!float 1e-2 pixel_weight: !!float 1e-2
feature_criterion: l1 feature_criterion: l1
feature_weight: 1 feature_weight: 1
gan_type: ragan # gan | ragan feature_weight_decay: .98
feature_weight_decay_steps: 500
feature_weight_minimum: .1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3 gan_weight: !!float 5e-3
D_update_ratio: 1 D_update_ratio: 2
D_init_iters: 0 D_init_iters: 0
manual_seed: 10 manual_seed: 10

View File

@ -1,58 +1,56 @@
#### general settings #### general settings
name: ESRGANx4_VRP name: blacked_fix_and_upconv
use_tb_logger: true use_tb_logger: true
model: srgan model: srgan
distortion: sr distortion: sr
scale: 4 scale: 4
gpu_ids: [0] gpu_ids: [0]
amp_opt_level: O1
#### datasets #### datasets
datasets: datasets:
train: train:
name: VRP name: vixcloseup
mode: LQGT mode: LQGT
dataroot_GT: ../datasets/vrp/train/hr dataroot_GT: K:\4k6k\4k_closeup\hr
dataroot_LQ: ../datasets/vrp/train/lr dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
use_shuffle: true use_shuffle: true
n_workers: 0 # per GPU n_workers: 12 # per GPU
batch_size: 16 batch_size: 12
target_size: 128 target_size: 256
use_flip: true
use_rot: true
color: RGB color: RGB
val: val:
name: VRP_val name: adrianna_val
mode: LQGT mode: LQGT
dataroot_GT: ../datasets/vrp/validation/hr dataroot_GT: E:\4k6k\datasets\adrianna\val\hhq
dataroot_LQ: ../datasets/vrp/validation/lr dataroot_LQ: E:\4k6k\datasets\adrianna\val\hr
#### network structures #### network structures
network_G: network_G:
which_model_G: RRDBNet which_model_G: RRDBNet
in_nc: 3 in_nc: 3
out_nc: 3 out_nc: 3
nf: 64 nf: 48
nb: 23 nb: 23
network_D: network_D:
which_model_D: discriminator_vgg_128 which_model_D: discriminator_resnet
in_nc: 3 in_nc: 3
nf: 64 nf: 48
#### path #### path
path: path:
pretrain_model_G: ../experiments/div2k_gen_pretrain.pth pretrain_model_G: ~
pretrain_model_D: ../experiments/div2k_disc_pretrain.pth
strict_load: true strict_load: true
resume_state: ~ resume_state: ~
#### training settings: learning rate scheme, loss #### training settings: learning rate scheme, loss
train: train:
lr_G: !!float 1e-5 lr_G: !!float 1e-4
weight_decay_G: 0 weight_decay_G: 0
beta1_G: 0.9 beta1_G: 0.9
beta2_G: 0.99 beta2_G: 0.99
lr_D: !!float 1e-5 lr_D: !!float 2e-4
weight_decay_D: 0 weight_decay_D: 0
beta1_D: 0.9 beta1_D: 0.9
beta2_D: 0.99 beta2_D: 0.99
@ -60,14 +58,17 @@ train:
niter: 400000 niter: 400000
warmup_iter: -1 # no warm up warmup_iter: -1 # no warm up
lr_steps: [50000, 100000, 200000, 300000] lr_steps: [20000, 40000, 60000, 80000]
lr_gamma: 0.5 lr_gamma: 0.5
pixel_criterion: l1 pixel_criterion: l1
pixel_weight: !!float 1e-2 pixel_weight: !!float 1e-2
feature_criterion: l1 feature_criterion: l1
feature_weight: 1 feature_weight: 1
gan_type: ragan # gan | ragan feature_weight_decay: .98
feature_weight_decay_steps: 500
feature_weight_minimum: .1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3 gan_weight: !!float 5e-3
D_update_ratio: 1 D_update_ratio: 1

View File

@ -21,6 +21,7 @@ datasets:
target_size: 64 target_size: 64
use_flip: false use_flip: false
use_rot: false use_rot: false
doCrop: false
color: RGB color: RGB
val: val:
name: blacked_val name: blacked_val
@ -48,18 +49,18 @@ network_D:
#### path #### path
path: path:
pretrain_model_G: ~ pretrain_model_G: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/models/29000_G.pth
pretrain_model_D: ~ #../experiments/resnet_corrupt_discriminator_fixup.pth pretrain_model_D: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/models/29000_D.pth
resume_state: ~ resume_state: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/training_state/29000.state
strict_load: true strict_load: true
#### training settings: learning rate scheme, loss #### training settings: learning rate scheme, loss
train: train:
lr_G: !!float 1e-4 lr_G: !!float 5e-5
weight_decay_G: 0 weight_decay_G: 0
beta1_G: 0.9 beta1_G: 0.9
beta2_G: 0.99 beta2_G: 0.99
lr_D: !!float 2e-4 lr_D: !!float 1e-4
weight_decay_D: 0 weight_decay_D: 0
beta1_D: 0.9 beta1_D: 0.9
beta2_D: 0.99 beta2_D: 0.99
@ -77,8 +78,8 @@ train:
gan_type: gan # gan | ragan gan_type: gan # gan | ragan
gan_weight: !!float 1e-1 gan_weight: !!float 1e-1
D_update_ratio: 2 D_update_ratio: 1
D_init_iters: 0 D_init_iters: -1
manual_seed: 10 manual_seed: 10
val_freq: !!float 5e2 val_freq: !!float 5e2

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_GAN_blacked_corrupt.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
@ -191,6 +191,8 @@ def main():
idx = 0 idx = 0
for val_data in val_loader: for val_data in val_loader:
idx += 1 idx += 1
if idx >= 20:
break
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name) img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir) util.mkdir(img_dir)