2023-03-21 15:38:42 +00:00

# DLAS Trainer
This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.
ExtensibleTrainer has three main components: **steps**, **injectors** and **losses**.
## Steps
A step loosely corresponds to all of the computation needed to perform a single PyTorch optimizer `step()` call, that
is:
1. Compute a forward pass.
1. Compute a loss.
1. Compute a backward pass and gather gradients.

A step also handles all of the logging and other 'homework' associated with the above.
Since DLAS often trains GANs, it needs to support optimizing multiple networks concurrently. This is why the notion of
a step is broken out of the trainer: a single training iteration can involve more than one optimizer step.
Most of the logic for how a step operates can be found in `steps.py`.
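The structure of a step can be sketched without any framework. In this minimal sketch, `ToyStep` and `train_iteration` are hypothetical names: a one-parameter model `y = w * x`, a squared-error loss and a hand-derived gradient stand in for a PyTorch network, loss module and `loss.backward()`, and a plain SGD update stands in for `optimizer.step()`. It is not the code in `steps.py`; it only illustrates how one training iteration can run several steps, each updating its own parameters.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyStep:
    """Hypothetical step owning one scalar parameter `w` and its own SGD 'optimizer'."""
    name: str
    w: float
    lr: float
    target: float
    history: List[float] = field(default_factory=list)

    def do_step(self, x: float) -> float:
        # 1. Forward pass of the toy network y = w * x.
        y = self.w * x
        # 2. Loss: squared error against a fixed target.
        loss = (y - self.target) ** 2
        # 3. Backward pass: d(loss)/dw, derived by hand instead of autograd.
        grad_w = 2.0 * (y - self.target) * x
        # Optimizer step (plain SGD) applied only to this step's parameter.
        self.w -= self.lr * grad_w
        # 'Homework': record the loss for logging.
        self.history.append(loss)
        return loss


def train_iteration(steps: List[ToyStep], x: float) -> Dict[str, float]:
    # One training iteration runs every step in order, e.g. a generator step
    # followed by a discriminator step, each with its own optimizer.
    return {s.name: s.do_step(x) for s in steps}


steps = [
    ToyStep("generator", w=0.0, lr=0.1, target=1.0),
    ToyStep("discriminator", w=0.0, lr=0.1, target=-1.0),
]
for _ in range(50):
    losses = train_iteration(steps, x=1.0)
```

Keeping parameters and update rules inside each step, rather than in the trainer, is what lets a single iteration drive two independently optimized networks.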
## Injectors
Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of
ExtensibleTrainer as a state machine, injectors are the means of mutating that state.
There are no hard rules on what an injector can do, but generally it operates as follows:
1. On startup, it is initialized with a configuration dict fed directly from the config file.
1. During the forward pass of a step, the injectors' `forward()` methods are invoked sequentially, with each one
receiving the current 'state' of the trainer.
1. The injector performs some computation and stores the result in the state. How these results are bound to the state
is generally defined within the configuration file, for example: "inject the output of this generator into key 'gen'".
1. Losses (discussed next) feed off of the state generated by the injectors.
1. After the step is completed, all injected state entries are detached from the autograd graph. This frees the
underlying GPU memory so the next step has as much memory as possible.
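The injector lifecycle can be illustrated with a plain dict as the trainer state. This is a sketch under assumptions: the `Injector` and `GeneratorInjector` classes, the `run_injectors` helper and the `in`/`out` config keys are hypothetical, and an arbitrary callable stands in for a generator network. The detach step is omitted because plain Python values have no autograd graph.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


class Injector:
    """Hypothetical base class: built from a config dict, produces new state entries."""

    def __init__(self, opt: Dict[str, Any]):
        # Assumed config keys: 'in' names the state key to read, 'out' the key to write.
        self.input = opt.get("in")
        self.output = opt["out"]

    def forward(self, state: State) -> State:
        raise NotImplementedError


class GeneratorInjector(Injector):
    """Runs a 'network' (any callable here) on state[in] and binds the result to 'out'."""

    def __init__(self, opt: Dict[str, Any], network: Callable[[Any], Any]):
        super().__init__(opt)
        self.network = network

    def forward(self, state: State) -> State:
        return {self.output: self.network(state[self.input])}


def run_injectors(injectors: List[Injector], state: State) -> State:
    # Injectors run in order; each one sees the entries written by those before it.
    for inj in injectors:
        state.update(inj.forward(state))
    return state


state: State = {"lq": 2.0}
pipeline = [
    GeneratorInjector({"in": "lq", "out": "gen"}, network=lambda x: x * 3),
    GeneratorInjector({"in": "gen", "out": "gen_plus_one"}, network=lambda x: x + 1),
]
run_injectors(pipeline, state)
```

Because the key names live in the config dicts, rewiring which injector feeds which is a configuration change rather than a code change.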
### Example injectors
- Running a generator's forward pass and storing the result in the trainer state.
- Injecting flat values into the state (e.g. `torch.zeros`, `torch.ones`, `torch.rand`).
- Adding noise to a state variable.
- Performing differentiable augmentations on an image tensor.

See `injectors.py` for a full list of currently implemented injectors (and templates for adding your own).
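Injectors for flat values and for noise follow the same pattern. In this hypothetical sketch, Python lists stand in for tensors, `ConstantInjector` plays the role of a `torch.zeros`/`torch.ones` injector, and `AddNoiseInjector` uses the standard library's `random.Random.gauss`. The class names and config keys are illustrative, not the ones defined in `injectors.py`.

```python
import random
from typing import Any, Dict

State = Dict[str, Any]


class ConstantInjector:
    """Hypothetical: writes a list filled with a constant (cf. torch.zeros / torch.ones)."""

    def __init__(self, opt: Dict[str, Any]):
        self.output = opt["out"]
        self.value = opt.get("value", 0.0)
        self.size = opt["size"]

    def forward(self, state: State) -> State:
        return {self.output: [self.value] * self.size}


class AddNoiseInjector:
    """Hypothetical: writes a noisy copy of an existing state entry under a new key."""

    def __init__(self, opt: Dict[str, Any], seed: int = 0):
        self.input = opt["in"]
        self.output = opt["out"]
        self.scale = opt.get("scale", 1.0)
        self.rng = random.Random(seed)  # seeded so the sketch is reproducible

    def forward(self, state: State) -> State:
        noisy = [v + self.rng.gauss(0.0, self.scale) for v in state[self.input]]
        return {self.output: noisy}


state: State = {}
state.update(ConstantInjector({"out": "z", "size": 4}).forward(state))
state.update(AddNoiseInjector({"in": "z", "out": "z_noisy", "scale": 0.1}).forward(state))
```

In this sketch the noise injector writes to a new key instead of modifying its input in place, so the clean entry stays available to any loss that still needs it.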
### Rules of thumb
Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think
carefully about whether it would be better done in your model architecture code. It is technically feasible to
implement entire models with injectors, but doing so results in unreadable configs. Strike a balance between
configurability and maintainability.
## Losses
Losses convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The
output of each loss is multiplied by its weight, and all of the weighted losses are summed together before performing
the backward pass.
Some models output a loss directly. This is fine: you can use the `direct` loss to accommodate them.
Losses are defined in `losses.py`.
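The weighting rule can be written out directly. In this sketch, `l1_loss`, `direct_loss` and `total_loss` are hypothetical helpers over a dict-based state holding plain floats; in a real trainer each term would be a differentiable tensor, and the returned sum is what `backward()` would be called on. `direct_loss` only mimics the idea of the `direct` loss by reading a precomputed value from the state.

```python
from typing import Any, Callable, Dict, List, Tuple

State = Dict[str, Any]
LossFn = Callable[[State], float]


def l1_loss(key_a: str, key_b: str) -> LossFn:
    # Hypothetical: mean absolute difference between two state entries.
    def fn(state: State) -> float:
        a, b = state[key_a], state[key_b]
        return sum(abs(x - y) for x, y in zip(a, b)) / len(a)
    return fn


def direct_loss(key: str) -> LossFn:
    # Hypothetical analogue of the `direct` loss: the model already wrote
    # its loss into the state, so it is simply read back.
    def fn(state: State) -> float:
        return state[key]
    return fn


def total_loss(losses: List[Tuple[float, LossFn]], state: State) -> float:
    # Each loss is multiplied by its weight and the weighted terms are summed
    # into the single value the backward pass would start from.
    return sum(weight * fn(state) for weight, fn in losses)


state: State = {"gen": [1.0, 2.0], "hq": [1.5, 2.5], "model_loss": 0.2}
weighted = [(1.0, l1_loss("gen", "hq")), (0.5, direct_loss("model_loss"))]
loss = total_loss(weighted, state)
```

Here the L1 term contributes 1.0 × 0.5 and the direct term 0.5 × 0.2, giving a total of 0.6.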
## Evaluators
As DLAS was extended beyond super-resolution (SR), it became necessary to support more complicated evaluation
behaviors, e.g. FID or SRFlow's Gaussian distance. To enable this, the concept of the `Evaluator` was added. Classes in
the `eval` folder contain various evaluator implementations. These can be referenced directly in the `eval` section of
your config file and will be executed alongside (or instead of) your validation set.
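An evaluator can be thought of as an object that computes metrics on demand, separately from the losses. This sketch is hypothetical: the `Evaluator` base class, the `perform_eval` method name, the `run_evaluators` helper and the mean-absolute-error metric (standing in for something like FID) are assumptions, not the classes in the `eval` folder.

```python
from typing import Callable, Dict, List, Sequence, Tuple


class Evaluator:
    """Hypothetical interface: returns a dict of named metrics when asked."""

    def perform_eval(self) -> Dict[str, float]:
        raise NotImplementedError


class MeanAbsoluteErrorEvaluator(Evaluator):
    # Stand-in for a real metric: compares model outputs against references.
    def __init__(self, model: Callable[[float], float],
                 dataset: Sequence[Tuple[float, float]]):
        self.model = model
        self.dataset = dataset

    def perform_eval(self) -> Dict[str, float]:
        errors = [abs(self.model(x) - y) for x, y in self.dataset]
        return {"mae": sum(errors) / len(errors)}


def run_evaluators(evaluators: List[Evaluator]) -> Dict[str, float]:
    # Merge the metrics from every evaluator, e.g. to log them next to validation results.
    metrics: Dict[str, float] = {}
    for ev in evaluators:
        metrics.update(ev.perform_eval())
    return metrics


evaluator = MeanAbsoluteErrorEvaluator(lambda x: 2 * x, [(1.0, 2.0), (2.0, 5.0)])
metrics = run_evaluators([evaluator])
```

Because evaluators return plain metric dicts rather than losses, they never take part in the backward pass and can be run on whatever schedule the config chooses.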