srflow recipes documentation

James Betker 2020-12-04 00:32:48 -07:00
recipes/srflow/srflow.md:
# Working with SRFlow in DLAS
[SRFlow](https://arxiv.org/abs/2006.14200) is a normalizing-flow-based SR technique that eschews GANs entirely in favor
of attaching an SR network to an invertible flow network, with the objective of reducing the details of a high-resolution
image into noise indistinguishable from a Gaussian distribution. In the process of doing so, the SRFlow network
actually trains the underlying SR network to a fairly amazing degree. The end product is a network pair that is adept
at SR, restoration, and extracting high-frequency outliers from HQ images.

As of November 2020, this is a new addition to this codebase. The SRFlow code was ported directly from the
[author's github](https://github.com/andreas128/SRFlow) and is very rough. I'm currently experimenting with trained
models to determine whether it is worth cleaning up.
# Training SRFlow
SRFlow is trained in 3 steps:
1. Pre-train an SR network on an L1 pixel loss. The current state of SRFlow is tightly bound to the RRDB architecture,
   but that could easily be changed if desired. `train_div2k_rrdb_psnr.yml` provides a sample configuration file.
   Search for `<--` in that file, make the required modifications, and run it through the trainer:
   `python train.py -opt train_div2k_rrdb_psnr.yml`
   The authors recommend training for 200k iterations. I found that RRDB converges far sooner than this and stopped my
   training around 100k iterations.
1. Train the first stage of the SRFlow network, where the RRDB network is frozen and the SRFlow layers are "warmed up".
   `train_div2k_srflow.yml` can be used to do this:
   `python train.py -opt train_div2k_srflow.yml`
   The authors recommend training in this configuration for half of the entire SRFlow training time. Again, I find this
   unnecessary. The network converges to a stable Gaussian NLL on the validation set after ~20k-40k iterations, after
   which I recommend moving to the second stage.
1. Train the second stage of the SRFlow network, where the RRDB network is unfrozen. Do this by editing
   `train_div2k_srflow.yml` and setting `train_RRDB: true` (a sketch of this change follows the list).
   After moving to this phase, you should see the Gaussian NLL in the validation metrics start to decrease again. This
   is a really cool phase of training, where the gradient pressure from the NLL loss is actively improving your RRDB SR
   network!
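
Here is a minimal sketch of the stage-two edit. Only `train_RRDB` changes; the surrounding keys are copied from
`train_div2k_srflow.yml` below for context:
```
networks:
  generator:
    type: generator
    which_model_G: srflow_orig
    train_RRDB: true  # <-- Was false during the warm-up stage; true unfreezes the RRDB weights.
    pretrain_rrdb: ../experiments/pretrained_rrdb.pth
```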
# Using SRFlow
SRFlow networks have several interesting potential uses. I'll go over a few of them. I've written a script that you
might find useful for playing with trained SRFlow networks: `scripts/srflow_latent_space_playground.py`. This script
does not take arguments; you will need to modify the code directly. That is just a personal preference for these types
of tools.
## Super-resolution
Super-resolution is performed by feeding an LR image and a latent into the network. The latent is *supposed* to be from
a Gaussian distribution sized relative to the LR image, but this depends on how well the SRFlow network could adapt
itself to your image distribution. For example, I could not get the 8X SR networks anywhere near a Gaussian; they
always "stored" much of their structural information inside of the latent.
In practice, you can get pretty damned good SR results from this network by simply feeding in zeros for the latents.
This makes SRFlow output the "mean HQ" representation it has learned for any given LQ image. It is done by setting the
temperature input to the SRFlow network to 0. Here is an injector definition that does just that:
```
gen_inj:
  type: generator
  generator: generator
  in: [None, lq, None, 0, True] # <-- '0' here is the temperature.
  out: [gen]
```
You can also accomplish this in `srflow_latent_space_playground.py` by setting the mode to `temperature`.
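To sample diverse HQ candidates instead of the mean, pass a nonzero temperature. The training config below uses 0.4 for
its debug visuals, and the SRFlow paper reports results sampled at temperatures around 0.8. A sketch (the injector name
`gen_inj_diverse` is arbitrary):
```
gen_inj_diverse:
  type: generator
  generator: generator
  in: [None, lq, None, 0.8, True] # <-- A nonzero temperature draws latents from the learned distribution.
  out: [gen]
```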
## Restoration
This was touched on in the SRFlow paper. The authors recommend computing the latents of a corrupted image, then
performing normalization on them. The logic is that the SRFlow network doesn't "know" how to encode corrupted images,
so the process of normalizing the latents will cause it to output the nearest true HR image to the corrupted input.
In practice, this sometimes works for me and sometimes does not. SRFlow has a knack for producing NaNs in the reverse
direction when it encounters LR image and latent pairs that are too far out of the training distribution. This
manifests as black spots or areas of noise in the image.
What seems to work better is the procedure described above: feed your corrupted image into the SRFlow network with a
temperature of 0. This almost always works and generally produces more pleasing results.
You can tinker with the restoration described in the paper in the `srflow_latent_space_playground.py` script by using
the `restore` mode.
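For the temperature-0 approach, here is a minimal injector sketch. It assumes the corrupted image is fed as `hq` with
its downsampled version as `lq`; `restore_inj` and the `restored` key are names I made up, and the forward `z_inj` pass
is optional but useful for gauging how far the input is from the training distribution:
```
# Optional forward pass: encode the corrupted image's latents and NLL as a
# distribution-distance diagnostic.
z_inj:
  type: generator
  generator: generator
  in: [hq, lq, None, None, False]
  out: [z, nll]
# Reverse pass at temperature 0: produces the "mean HQ" restoration.
restore_inj:
  type: generator
  generator: generator
  in: [None, lq, None, 0, True]
  out: [restored]
```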
## Style transfer
The SRFlow network splits high-frequency information from HQ images by design. This high-frequency data is encoded in
the latents. These latents can then be fed back into the network with a different LR image to accomplish a sort of
style transfer. In the paper, the authors transfer fine facial features, and it seems to work well. This was hit or
miss for me, but I admittedly have not tried too hard (yet).
You can tinker with latent transfer in the script by using the `latent_transfer` mode. Note that this only does
whole-image latent transfer.
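A rough injector sketch of whole-image latent transfer. This assumes the generator's positional inputs follow the
`[hq, lq, z, temperature, reverse]` pattern implied by the injectors above, and that the dataset keys `hq_style`,
`lq_style`, and `lq_content` exist; all of these names are hypothetical:
```
# Forward pass on the style source: extract its latents.
# Assumed positional layout: [hq, lq, z, temperature, reverse] (inferred from the injectors above).
z_style_inj:
  type: generator
  generator: generator
  in: [hq_style, lq_style, None, None, False]
  out: [z_style, nll_style]
# Reverse pass on a different LR image, reusing the style latents.
transfer_inj:
  type: generator
  generator: generator
  in: [None, lq_content, z_style, None, True]
  out: [transferred]
```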
# Notes on validation
My validation runs are my own design. They work by feeding a set of HQ images from your target distribution through the
SRFlow network to produce latents. These latents are then compared to a Gaussian distribution, and the validation score
is the per-pixel distance from that distribution. I do not compute the log of this loss, since doing so hides the fine
improvements at the magnitudes this network operates at.
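One plausible reading of this score as a formula (my interpretation of the description above, not necessarily the exact
evaluator code): the mean squared latent value, which tends toward 1 when the latents truly come from a standard
Gaussian:

$$\text{score} = \frac{1}{N}\sum_{i=1}^{N} z_i^2$$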

recipes/srflow/train_div2k_rrdb_psnr.yml:
#### general settings
name: train_div2k_rrdb_psnr
use_tb_logger: true
model: extensibletrainer
distortion: sr
scale: 4
gpu_ids: [0]
fp16: false
start_step: 0
checkpointing_enabled: true # <-- Highly recommended for single-GPU training. Will not work with DDP.
wandb: false
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    eval: False
    num_corrupts_per_image: 0
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4
    force_multiple: 16
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_rrdb_psnr/training_state/0.state # <-- Set this to resume from a previous training state.
steps:
  generator:
    training: generator
    optimizer_params:
      # Optimizer params
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99
    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen
    losses:
      pix:
        type: pix
        weight: 1
        criterion: l1
        real: hq
        fake: gen
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 2000
  # Default LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 150000, 200000]
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, lq]
  visual_debug_rate: 100

recipes/srflow/train_div2k_srflow.yml:
#### general settings
name: train_div2k_srflow
use_tb_logger: true
model: extensibletrainer
distortion: sr
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true
wandb: false
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k # <-- Put your path here.
    target_size: 160 # <-- SRFlow trains better with factors of 160 for some reason.
    force_multiple: 1
    scale: 4
    eval: False
    num_corrupts_per_image: 0
    strict: false
networks:
  generator:
    type: generator
    which_model_G: srflow_orig
    nf: 64
    nb: 23
    K: 16
    scale: 4
    initial_stride: 2
    flow_scale: 4
    train_RRDB: false # <-- Start false. After some time, ~20k-50k steps, set to true. TODO: automate this.
    train_RRDB_delay: 0.5
    pretrain_rrdb: ../experiments/pretrained_rrdb.pth # <-- Insert the path to your pretrained RRDB here.
    flow:
      patch_size: 160
      K: 16
      L: 3
      act_norm_start_step: 100
      noInitialInj: true
      coupling: CondAffineSeparatedAndCond
      additionalFlowNoAffine: 2
      split:
        enable: true
      fea_up0: true
      fea_up-1: true
      stackRRDB:
        blocks: [ 1, 8, 15, 22 ]
        concat: true
    gaussian_loss_weight: 1
#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_srflow/training_state/0.state # <-- Set this to resume from a previous training state.
steps:
  generator:
    training: generator
    optimizer_params:
      # Optimizer params
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99
    injectors:
      z_inj:
        type: generator
        generator: generator
        in: [hq, lq, None, None, False]
        out: [z, nll]
      # This is computed solely for visual_dbg - that is, to see what your model is actually doing.
      gen_inj:
        every: 50
        type: generator
        generator: generator
        in: [None, lq, None, .4, True]
        out: [gen]
    losses:
      log_likelihood:
        type: direct
        key: nll
        weight: 1
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 1000
  # Default LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000]
  lr_gamma: 0.5

eval:
  evaluators:
    # This is the best metric I have come up with for monitoring the training progress of srflow networks. You should
    # feed this evaluator a random set of images from your target distribution.
    gaussian:
      for: generator
      type: flownet_gaussian
      batch_size: 2
      dataset:
        paths: /content/random_100_images
        target_size: 512
        force_multiple: 1
        scale: 4
        eval: False
        num_corrupts_per_image: 0
        corruption_blur_scale: 1
  output_state: eval_gen

logger:
  print_freq: 30
  save_checkpoint_freq: 500
  visuals: [gen, hq, lq]
  visual_debug_rate: 50