# Forked from mrq/DL-Art-School (179 lines, 3.9 KiB, YAML).
# Top-level settings for the `train_div2k_esrgan` configuration.
name: train_div2k_esrgan
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
# Gradient checkpointing. Enable for huge GPU memory savings.
# Disable for distributed training.
checkpointing_enabled: true
use_tb_logger: true
wandb: false

# Dataset definitions: one training set (`train`) and one validation set (`val`).
datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

# Network definitions. The names declared here (`generator`,
# `feature_discriminator`) are the ones the `steps` section refers to.
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    scale: 2
    nf: 64
    in_nc: 3
    image_size: 96

#### path
path:
  # Uncomment and fill in to start from a pretrained generator:
  # pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # Set this to resume from a previous training state:
  # resume_state: ../experiments/train_div2k_esrgan/training_state/0.state

# Training steps. Each step names the network it trains (`training`), the
# injectors that produce intermediate state, and the losses applied to it.
steps:

  feature_discriminator:
    training: feature_discriminator
    # Discriminator doesn't "turn-on" until step 100k to allow generator to
    # anneal on PSNR loss.
    after: 100000

    # Optimizer params
    # NOTE(review): these keys sit directly on the step, whereas the generator
    # step below nests the same keys under `optimizer_params` -- confirm which
    # form the trainer expects.
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      # "image_patch" injectors support the translational loss below. You can
      # remove them if you remove that loss.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96
        in: hq
        out: phq
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        # min_loss: .4
        noise: .004
        gradient_penalty: true
        real: phq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96
        in: hq
        out: pghq
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        # Same step count as the `after` on the feature_discriminator step.
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: .02
        noise: .004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but I
      # find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: ['pglq']

# Training-loop settings and learning-rate schedule.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  # LR is halved at these steps. Don't do it until GAN is online.
  gen_lr_steps: [140000, 180000, 200000, 240000]
  lr_gamma: 0.5

# Evaluation settings; `gen` is the output name of the `gen_inj` injector
# in the generator step.
eval:
  output_state: gen

# Logging, checkpoint and visual-dump settings.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, pglq, pghq]
  visual_debug_rate: 100