srflow recipes documentation
This commit is contained in:
parent
711780126e
commit
a0d8f806a6
97
recipes/srflow/srflow.md
Normal file
97
recipes/srflow/srflow.md
Normal file
|
@ -0,0 +1,97 @@
|
|||
# Working with SRFlow in DLAS
|
||||
|
||||
[SRFlow](https://arxiv.org/abs/2006.14200) is a normalizing-flow based SR technique that eschews GANs entirely in favor
|
||||
of hooking a SR network to an invertible flow network with the objective of reducing the details of a high-resolution
|
||||
image into noise indistinguishable from the Gaussian distribution. In the process of doing so, the SRFlow network
|
||||
actually trains the underlying SR network to a fairly amazing degree. The end product is a network pair that is adept
|
||||
at SR, restoration, and extracting high frequency outliers from HQ images.
|
||||
|
||||
As of November 2020, this is a new addition to this codebase. The SRFlow code was ported directly from the
|
||||
[author's github](https://github.com/andreas128/SRFlow), and is very rough. I'm currently experimenting with trained
|
||||
models to determine whether it is worth cleaning up.
|
||||
|
||||
# Training SRFlow
|
||||
|
||||
SRFlow is trained in 3 steps:
|
||||
|
||||
1. Pre-train an SR network on a L1 pixel loss. The current state of SRFlow is highly bound to the RRDB architecture
|
||||
but that could be changed if desired easily enough. `train_div2k_rrdb_psnr.yml` provides a sample configuration file.
|
||||
Search for `<--` in that file, make the required modifications, and run it through the trainer:
|
||||
|
||||
`python train.py -opt train_div2k_rrdb_psnr.yml`
|
||||
|
||||
The authors recommended training for 200k iterations. I found RRDB converges far sooner than this and stopped my
|
||||
training around 100k iterations.
|
||||
1. Train the first stage of the SRFlow network, where the RRDB network is frozen and the SRFlow layers are "warmed up".
|
||||
`train_div2k_srflow.yml` can be used to do this:
|
||||
|
||||
`python train.py -opt train_div2k_srflow.yml`
|
||||
|
||||
The authors recommend training in this configuration for half of the entire SRFlow training time. Again, I find this
|
||||
unnecessary. I saw that the network converges to a stable gaussian NLL on the validation set after ~20k-40k iterations,
|
||||
after which I recommend moving to stage 2.
|
||||
1. Train the second stage of the SRFlow network, where the RRDB network is unfrozen. Do this by editing `train_div2k_srflow.yml`
|
||||
and setting `train_RRDB=true`.
|
||||
|
||||
After moving to this phase, you should see the gaussian NLL in the validation metrics start to decrease again. This
|
||||
is a really cool phase of training, where the gradient pressure from the NLL loss is actively improving your RRDB SR
|
||||
network!
|
||||
|
||||
# Using SRFlow
|
||||
|
||||
SRFlow networks have several interesting potential uses. I'll go over a few of them. I've written a script that you
|
||||
might find useful for playing with trained SRFlow networks: `scripts/srflow_latent_space_playground.py`. This script does not
|
||||
take arguments, you will need to modify the code directly. Just a personal preference for these types of tools.
|
||||
|
||||
## Super-resolution
|
||||
|
||||
Super resolution is performed by feeding an LR image and a latent into the network. The latent is *supposed* to be from
|
||||
a gaussian distribution sized relative to the LR image, but this depends on how well the SRFlow network could adapt
|
||||
itself to your image distribution. For example, I could not get the 8X SR networks to get anywhere near a gaussian; they
|
||||
always "stored" much of their structural information inside of the latent.
|
||||
|
||||
In practice, you can get pretty damned good SR results from this network by simply feeding in zeros for the latents. This
|
||||
makes the SRFlow show the "mean HQ" representation it has learned for any given LQ image. It is done by setting the
|
||||
temperature input to the SRFlow network to 0. Here is an injector definition that does just that:
|
||||
```
|
||||
gen_inj:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: [None, lq, None, 0, True] # <-- '0' here is the temperature.
|
||||
out: [gen]
|
||||
```
|
||||
|
||||
You can also accomplish this in `srflow_latent_space_playground.py` by setting the mode to `temperature`.
|
||||
|
||||
## Restoration
|
||||
|
||||
This was touched on in the SRFlow paper. The authors recommend computing the latents of a corrupted image, then
|
||||
performing normalization on it. The logic is that the SRFlow network doesn't "know" how to compute corrupted images, so
|
||||
the process of normalizing the latents will cause it to output the nearest true HR image to the corrupted input image.
|
||||
|
||||
In practice, this works sometimes for me, sometimes not. SRFlow has a knack for producing NaNs in the reverse direction
|
||||
when it encounters LR images and latent pairs that are too far out of the training distribution. This manifests as
|
||||
black spots or areas of noise in the image.
|
||||
|
||||
In practice, what seems to work better is using the above procedure: feed your corrupted image into the SRFlow network
|
||||
with a temperature of 0. This will almost always works and generally produces more pleasing results.
|
||||
|
||||
You can tinker with the restoration described in the paper in the `srflow_latent_space_playground.py` script by using
|
||||
the `restore` mode.
|
||||
|
||||
## Style transfer
|
||||
|
||||
The SRFlow network splits high frequency information from HQ images by design. This high frequency data is encoded in
|
||||
the latents. These latents can then be fed back into the network with a different LR image to accomplish a sort of
|
||||
style transfer. In the paper, the authors transfer fine facial features and it seems to work well. This was hit or miss
|
||||
for me, but I admittedly did not try to hard (yet).
|
||||
|
||||
You can tinker with latent transfer in the script by using the `latent_transfer` mode. Note that this only does whole-
|
||||
image latent transfer.
|
||||
|
||||
# Notes on validation
|
||||
|
||||
My validation runs are my own design. The work by feeding a set of HQ images from your target distribution through the
|
||||
SRFlow network to produce latents. These latents are then compared to a gaussian distribution and the validation score
|
||||
is the per-pixel distance from that distribution. I do not compute the log of the loss since this hides fine improvements
|
||||
at the log levels that this network operates in.
|
94
recipes/srflow/train_div2k_rrdb_psnr.yml
Normal file
94
recipes/srflow/train_div2k_rrdb_psnr.yml
Normal file
|
@ -0,0 +1,94 @@
|
|||
#### general settings
|
||||
name: train_div2k_rrdb_psnr
|
||||
use_tb_logger: true
|
||||
model: extensibletrainer
|
||||
distortion: sr
|
||||
scale: 2
|
||||
gpu_ids: [0]
|
||||
fp16: false
|
||||
start_step: 0
|
||||
checkpointing_enabled: true # <-- Highly recommended for single-GPU training. Will not work with DDP.
|
||||
wandb: false
|
||||
|
||||
datasets:
|
||||
train:
|
||||
n_workers: 4
|
||||
batch_size: 32
|
||||
name: div2k
|
||||
mode: single_image_extensible
|
||||
paths: /content/div2k # <-- Put your path here.
|
||||
target_size: 128
|
||||
force_multiple: 1
|
||||
scale: 4
|
||||
eval: False
|
||||
num_corrupts_per_image: 0
|
||||
strict: false
|
||||
val:
|
||||
name: val
|
||||
mode: fullimage
|
||||
dataroot_GT: /content/set14
|
||||
scale: 4
|
||||
force_multiple: 16
|
||||
|
||||
networks:
|
||||
generator:
|
||||
type: generator
|
||||
which_model_G: RRDBNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 23
|
||||
scale: 4
|
||||
blocks_per_checkpoint: 3
|
||||
|
||||
#### path
|
||||
path:
|
||||
#pretrain_model_generator: <insert pretrained model path if desired>
|
||||
strict_load: true
|
||||
#resume_state: ../experiments/train_div2k_rrdb_psnr/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||
|
||||
steps:
|
||||
generator:
|
||||
training: generator
|
||||
|
||||
optimizer_params:
|
||||
# Optimizer params
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
|
||||
injectors:
|
||||
gen_inj:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: lq
|
||||
out: gen
|
||||
|
||||
losses:
|
||||
pix:
|
||||
type: pix
|
||||
weight: 1
|
||||
criterion: l1
|
||||
real: hq
|
||||
fake: gen
|
||||
|
||||
train:
|
||||
niter: 500000
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 2000
|
||||
|
||||
# Default LR scheduler options
|
||||
default_lr_scheme: MultiStepLR
|
||||
gen_lr_steps: [50000, 100000, 150000, 200000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
output_state: gen
|
||||
|
||||
logger:
|
||||
print_freq: 30
|
||||
save_checkpoint_freq: 1000
|
||||
visuals: [gen, hq, lq]
|
||||
visual_debug_rate: 100
|
128
recipes/srflow/train_div2k_srflow.yml
Normal file
128
recipes/srflow/train_div2k_srflow.yml
Normal file
|
@ -0,0 +1,128 @@
|
|||
#### general settings
|
||||
name: train_div2k_srflow
|
||||
use_tb_logger: true
|
||||
model: extensibletrainer
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
fp16: false
|
||||
start_step: -1
|
||||
checkpointing_enabled: true
|
||||
wandb: false
|
||||
|
||||
datasets:
|
||||
train:
|
||||
n_workers: 4
|
||||
batch_size: 32
|
||||
name: div2k
|
||||
mode: single_image_extensible
|
||||
paths: /content/div2k # <-- Put your path here.
|
||||
target_size: 160 # <-- SRFlow trains better with factors of 160 for some reason.
|
||||
force_multiple: 1
|
||||
scale: 4
|
||||
eval: False
|
||||
num_corrupts_per_image: 0
|
||||
strict: false
|
||||
|
||||
networks:
|
||||
generator:
|
||||
type: generator
|
||||
which_model_G: srflow_orig
|
||||
nf: 64
|
||||
nb: 23
|
||||
K: 16
|
||||
scale: 4
|
||||
initial_stride: 2
|
||||
flow_scale: 4
|
||||
train_RRDB: false # <-- Start false. After some time, ~20k-50k steps, set to true. TODO: automate this.
|
||||
train_RRDB_delay: 0.5
|
||||
pretrain_rrdb: ../experiments/pretrained_rrdb.pth # <-- Insert path to your pretrained RRDB here.
|
||||
|
||||
flow:
|
||||
patch_size: 160
|
||||
K: 16
|
||||
L: 3
|
||||
act_norm_start_step: 100
|
||||
noInitialInj: true
|
||||
coupling: CondAffineSeparatedAndCond
|
||||
additionalFlowNoAffine: 2
|
||||
split:
|
||||
enable: true
|
||||
fea_up0: true
|
||||
fea_up-1: true
|
||||
stackRRDB:
|
||||
blocks: [ 1, 8, 15, 22 ]
|
||||
concat: true
|
||||
gaussian_loss_weight: 1
|
||||
|
||||
#### path
|
||||
path:
|
||||
#pretrain_model_generator: <insert pretrained model path if desired>
|
||||
strict_load: true
|
||||
#resume_state: ../experiments/train_div2k_srflow/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||
|
||||
steps:
|
||||
generator:
|
||||
training: generator
|
||||
|
||||
optimizer_params:
|
||||
# Optimizer params
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
|
||||
injectors:
|
||||
z_inj:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: [hq, lq, None, None, False]
|
||||
out: [z, nll]
|
||||
# This is computed solely for visual_dbg - that is, to see what your model is actually doing.
|
||||
gen_inj:
|
||||
every: 50
|
||||
type: generator
|
||||
generator: generator
|
||||
in: [None, lq, None, .4, True]
|
||||
out: [gen]
|
||||
|
||||
losses:
|
||||
log_likelihood:
|
||||
type: direct
|
||||
key: nll
|
||||
weight: 1
|
||||
|
||||
train:
|
||||
niter: 500000
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 1000
|
||||
|
||||
# Default LR scheduler options
|
||||
default_lr_scheme: MultiStepLR
|
||||
gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
evaluators:
|
||||
# This is the best metric I have come up with for monitoring the training progress of srflow networks. You should
|
||||
# feed this evaluator a random set of images from your target distribution.
|
||||
gaussian:
|
||||
for: generator
|
||||
type: flownet_gaussian
|
||||
batch_size: 2
|
||||
dataset:
|
||||
paths: /content/random_100_images
|
||||
target_size: 512
|
||||
force_multiple: 1
|
||||
scale: 4
|
||||
eval: False
|
||||
num_corrupts_per_image: 0
|
||||
corruption_blur_scale: 1
|
||||
output_state: eval_gen
|
||||
|
||||
logger:
|
||||
print_freq: 30
|
||||
save_checkpoint_freq: 500
|
||||
visuals: [gen, hq, lq]
|
||||
visual_debug_rate: 50
|
Loading…
Reference in New Issue
Block a user