forked from mrq/DL-Art-School
# DLAS Trainer

This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.

ExtensibleTrainer has three main components: **steps**, **injectors** and **losses**.
## Steps

A step loosely encompasses all of the computation needed to perform one call to a PyTorch optimizer's `step()` function. That is:

1. Compute a forward pass.
1. Compute a loss.
1. Compute a backward pass and gather gradients.

It also covers all of the logging and other 'homework' associated with the above.

Since DLAS often trains GANs, it needs to support optimizing multiple networks concurrently. This is why the notion of a step is broken out of the trainer: each training iteration can correspond to more than one optimizer step.

Most of the logic for how a step operates can be found in `steps.py`.
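The split between one training iteration and several optimizer steps can be illustrated with a small, framework-free sketch. Everything here is hypothetical: `Param`, `SGD` and `Step` are invented stand-ins (not the classes in `steps.py`), and gradients are written out by hand where PyTorch would use autograd. A "generator" scalar is pulled towards 3.0, and an auxiliary scalar is trained in its own step to match the generator's output, loosely mirroring how a GAN trains two networks per iteration.

```python
# Hypothetical, framework-free sketch of steps; not the DLAS implementation.
from dataclasses import dataclass
from typing import Callable, Dict, List

State = Dict[str, float]


@dataclass
class Param:
    """A trainable scalar plus its accumulated gradient (stand-in for a tensor)."""
    value: float
    grad: float = 0.0


class SGD:
    """Stand-in for a PyTorch optimizer: step() applies, then clears, the gradients."""

    def __init__(self, params: List[Param], lr: float) -> None:
        self.params = params
        self.lr = lr

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad
            p.grad = 0.0


@dataclass
class Step:
    """Hypothetical step: forward pass, loss, backward pass, then one optimizer step."""
    name: str
    forward: Callable[[State], None]   # writes its outputs into the shared state
    loss: Callable[[State], float]     # reads the state and returns a scalar loss
    backward: Callable[[State], None]  # accumulates gradients into the parameters
    optimizer: SGD

    def run(self, state: State, log: Dict[str, float]) -> None:
        self.forward(state)
        log[self.name] = self.loss(state)  # the logging "homework"
        self.backward(state)
        self.optimizer.step()


gen, aux = Param(0.0), Param(0.0)


def gen_forward(state: State) -> None:
    state["gen"] = gen.value


def gen_loss(state: State) -> float:
    return (state["gen"] - 3.0) ** 2


def gen_backward(state: State) -> None:
    gen.grad += 2.0 * (state["gen"] - 3.0)  # d/dgen of (gen - 3)^2


def aux_forward(state: State) -> None:
    state["aux"] = aux.value


def aux_loss(state: State) -> float:
    return (state["aux"] - state["gen"]) ** 2


def aux_backward(state: State) -> None:
    # d/daux of (aux - gen)^2; the generator output is treated as a constant.
    aux.grad += 2.0 * (state["aux"] - state["gen"])


steps = [
    Step("gen", gen_forward, gen_loss, gen_backward, SGD([gen], lr=0.1)),
    Step("aux", aux_forward, aux_loss, aux_backward, SGD([aux], lr=0.1)),
]

log: Dict[str, float] = {}
for _ in range(200):
    state: State = {}  # a fresh state for every training iteration
    for step in steps:  # one training iteration = two optimizer steps
        step.run(state, log)
```

In real PyTorch code the backward pass would be `loss.backward()` and the optimizer a `torch.optim` instance; the hand-written gradients only keep this sketch dependency-free.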
## Injectors

Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of ExtensibleTrainer as a state machine, injectors are the means of mutating that state.

There are no hard rules on what an injector can do, but it generally operates as follows:

1. On startup, it is initialized with a configuration dict fed directly from the config file.
1. During the forward pass of a step, the `injector.forward()` methods are invoked sequentially, each one receiving the current 'state' of the trainer.
1. The injector performs some computation and stores the result in the state. How these results are bound to the state is generally defined within the configuration file, for example "inject the output of this generator into key 'gen'".
1. Losses (discussed below) consume the state generated by the injectors.
1. After the step is completed, all injected state is detached. This frees the underlying GPU memory so the next step has as much memory as possible.
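The injector lifecycle described above can be sketched as a plain dictionary passed through a sequence of objects. This is an illustrative sketch, not the code in `injectors.py`: the `Injector` and `GeneratorInjector` classes, the `build_injector` and `run_forward` helpers, and the config keys `type`, `generator`, `in` and `out` are all assumptions made for the example, and ordinary Python functions stand in for networks.

```python
# Hypothetical sketch of config-driven injectors; not the DLAS implementation.
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


class Injector:
    """Hypothetical base class: built from a config dict, returns new state entries."""

    def __init__(self, opt: Dict[str, Any]) -> None:
        self.opt = opt  # taken verbatim from the config file

    def forward(self, state: State) -> State:
        raise NotImplementedError


class GeneratorInjector(Injector):
    """Feeds state[opt['in']] through a network and binds the result to opt['out']."""

    def __init__(self, opt: Dict[str, Any], networks: Dict[str, Callable]) -> None:
        super().__init__(opt)
        self.net = networks[opt["generator"]]

    def forward(self, state: State) -> State:
        return {self.opt["out"]: self.net(state[self.opt["in"]])}


def build_injector(opt: Dict[str, Any], networks: Dict[str, Callable]) -> Injector:
    # Each injector is initialized from its config dict on startup.
    if opt["type"] == "generator":
        return GeneratorInjector(opt, networks)
    raise ValueError(f"unknown injector type: {opt['type']}")


def run_forward(injectors: List[Injector], state: State) -> State:
    # Injectors run in order; each sees the state left behind by the previous ones.
    for injector in injectors:
        state.update(injector.forward(state))
    return state


networks = {"upsampler": lambda x: 2 * x, "refiner": lambda x: x + 1}
config = [
    {"type": "generator", "generator": "upsampler", "in": "lq", "out": "gen"},
    {"type": "generator", "generator": "refiner", "in": "gen", "out": "refined"},
]
injectors = [build_injector(opt, networks) for opt in config]
state = run_forward(injectors, {"lq": 1.5})
print(state)  # {'lq': 1.5, 'gen': 3.0, 'refined': 4.0}
```

In a PyTorch version the state values would be tensors, and once the step finished the injected entries would be detached (for example with `tensor.detach()`) so their autograd graphs can be freed.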
### Example injectors

- Running a forward pass through a generator and storing the result in the program state.
- Injecting flat values into the state (e.g. `torch.zeros`, `torch.ones`, `torch.rand`).
- Adding noise to a state variable.
- Performing differentiable augmentations on an image tensor.

See `injectors.py` for a full list of currently implemented injectors (and templates for adding your own).
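Two of the simpler injector kinds listed above, flat values and additive noise, can be illustrated without any framework. The sketch below is hypothetical: the function names and the `shape_of`, `scale` and `seed` parameters are invented, and plain lists of floats stand in for tensors (a PyTorch version would more likely use `torch.zeros_like` and `torch.randn_like`).

```python
# Hypothetical sketch of two simple injectors; lists of floats stand in for tensors.
import random
from typing import Dict, List

State = Dict[str, List[float]]


def constant_injector(state: State, shape_of: str, out: str, value: float = 0.0) -> State:
    """Injects a flat 'tensor' with the same length as state[shape_of]."""
    return {out: [value] * len(state[shape_of])}


def noise_injector(state: State, key: str, out: str, scale: float, seed: int = 0) -> State:
    """Adds Gaussian noise with standard deviation `scale` to state[key]."""
    rng = random.Random(seed)  # seeded so the example is reproducible
    return {out: [x + rng.gauss(0.0, scale) for x in state[key]]}


state: State = {"image": [0.2, 0.5, 0.9]}
state.update(constant_injector(state, shape_of="image", out="zeros"))
state.update(noise_injector(state, key="image", out="noisy_image", scale=0.1))
print(state["zeros"])  # [0.0, 0.0, 0.0]
```

Both injectors return new entries rather than modifying `state[key]` in place, so the original value remains available to later injectors and losses.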
### Rules of thumb

Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think long and hard about whether it would be better done in your model architecture code. It is technically feasible to implement entire models with injectors, but doing so results in unreadable configs. Strike a balance between configurability and maintainability.
## Losses

Losses convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The output of each loss is multiplied by its weight, and all the weighted losses are summed together before performing the backward pass.

Some models directly output a loss. This is fine: you can use the `direct` loss to accommodate this.

Losses are defined in `losses.py`.
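The weighting scheme can be written out as a small sketch. This is an illustration under assumptions, not the code in `losses.py`: losses are plain functions of the state, and the config keys `type`, `weight`, `fake`, `real` and `key` (including the exact shape of a `direct` loss entry) are hypothetical.

```python
# Hypothetical sketch of weighted, config-driven losses; not the DLAS implementation.
from typing import Any, Callable, Dict, List

State = Dict[str, float]
LossFn = Callable[[State], float]


def make_loss(opt: Dict[str, Any]) -> LossFn:
    """Builds a loss function from a (hypothetical) config dict."""
    if opt["type"] == "l2":
        # Squared error between two state entries.
        return lambda state: (state[opt["fake"]] - state[opt["real"]]) ** 2
    if opt["type"] == "direct":
        # The model already produced a loss; just read it out of the state.
        return lambda state: state[opt["key"]]
    raise ValueError(f"unknown loss type: {opt['type']}")


def total_loss(state: State, loss_opts: List[Dict[str, Any]]) -> float:
    """Multiplies each loss by its mandatory weight and sums the results."""
    return sum(opt["weight"] * make_loss(opt)(state) for opt in loss_opts)


state = {"gen": 2.0, "target": 3.0, "model_loss": 4.0}
loss_opts = [
    {"type": "l2", "weight": 1.0, "fake": "gen", "real": "target"},
    {"type": "direct", "weight": 0.5, "key": "model_loss"},
]
print(total_loss(state, loss_opts))  # 1.0 * 1.0 + 0.5 * 4.0 = 3.0
```

With tensors instead of floats, the sum would itself be a tensor, and a single `.backward()` call on it would propagate gradients through every weighted term.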
## Evaluators

As DLAS was extended beyond super-resolution (SR), it became necessary to support more complicated evaluation behaviors, e.g. FID or SRFlow's Gaussian distance. To enable this, the concept of the `Evaluator` was added. Classes in the `eval` folder contain various evaluator implementations. These can be fed directly into the `eval` section of your config file and will be executed alongside (or instead of) your validation set.
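An evaluator can be thought of as an object that is periodically handed the model and returns named metrics. The sketch below is hypothetical: the `Evaluator` base class, its `perform_eval()` method, the `run_evals` helper and the mean-absolute-error example are invented for illustration and do not reproduce the classes in the `eval` folder (a real evaluator would compute something like FID over a dataset).

```python
# Hypothetical sketch of the evaluator concept; not the classes in the eval folder.
from typing import Callable, Dict, List, Tuple

Model = Callable[[float], float]


class Evaluator:
    """Hypothetical evaluator base class: wraps a model, returns a dict of metrics."""

    def __init__(self, model: Model) -> None:
        self.model = model

    def perform_eval(self) -> Dict[str, float]:
        raise NotImplementedError


class MeanAbsErrorEvaluator(Evaluator):
    """Example metric: mean absolute error over a small held-out set."""

    def __init__(self, model: Model, data: List[Tuple[float, float]]) -> None:
        super().__init__(model)
        self.data = data

    def perform_eval(self) -> Dict[str, float]:
        errors = [abs(self.model(x) - y) for x, y in self.data]
        return {"mae": sum(errors) / len(errors)}


def run_evals(evaluators: List[Evaluator]) -> Dict[str, float]:
    """Merges the metrics of every configured evaluator, e.g. at validation time."""
    metrics: Dict[str, float] = {}
    for evaluator in evaluators:
        metrics.update(evaluator.perform_eval())
    return metrics


def double(x: float) -> float:
    return 2 * x


print(run_evals([MeanAbsErrorEvaluator(double, [(1.0, 2.0), (2.0, 5.0)])]))  # {'mae': 0.5}
```

Keeping each evaluator behind a single method that returns a metrics dict is what would let a trainer run any number of them from config, with or without a conventional validation set.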