# DLAS Trainer
This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.

ExtensibleTrainer has three main components: **steps**, **injectors** and **losses**.
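To see how the three relate, a configuration for a trainer like this typically nests injectors and losses under named steps. The skeleton below is purely illustrative; the actual DLAS option names and layout may differ:

```yaml
steps:
  generator_step:        # one step == one optimizer step() cycle (hypothetical name)
    injectors:
      gen_inj:           # run the generator, bind its output to state key 'gen'
        type: generator
        in: lq
        out: gen
    losses:
      pix:               # a weighted loss that reads keys from the state
        type: pix
        weight: 1.0
```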
## Steps
A step is loosely associated with all of the computation needed to perform a PyTorch optimizer's `step()` call. That is:

1. Compute a forward pass.
1. Compute a loss.
1. Compute a backward pass and gather gradients.

It also covers all of the logging and other 'homework' associated with the above.
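The cycle above can be sketched with plain callables standing in for PyTorch modules, losses and optimizers. All names here are hypothetical; the real logic lives in `steps.py`:

```python
# Conceptual sketch of one step: forward pass, loss, then backward/optimize.
# forward_fn, loss_fn and optimizer_step_fn are stand-ins for a network,
# a loss module, and the backward-pass-plus-optimizer machinery.

def run_step(forward_fn, loss_fn, optimizer_step_fn, batch):
    out = forward_fn(batch)        # 1. forward pass
    loss = loss_fn(out)            # 2. compute a loss
    optimizer_step_fn(loss)        # 3. backward pass + optimizer step()
    return out, loss
```

In the real trainer each of these pieces is built from the configuration file rather than hard-coded.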
Since DLAS often trains GANs, it necessarily needs to support optimizing multiple networks concurrently. This is why the notion of a step is broken out of the trainer: each training step can correspond to more than one optimizer step.

Most of the logic for how a step operates can be found in `steps.py`.
## Injectors
Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of ExtensibleTrainer as a state machine, injectors are the way to mutate that state.

There are no hard rules on what an injector can do, but this is generally how one operates:
1. On startup, it is initialized with a configuration dict fed directly from the config file.
1. During the forward pass of a step, each injector's `forward()` method is invoked sequentially, with each one receiving the current 'state' of the trainer.
1. The injector performs some computation and stores the result in the state. How these results are bound to the state is generally defined within the configuration file. For example: "inject the output of this generator into key 'gen'".
1. Losses (discussed next) feed off of the state generated by the injectors.
1. After the step is completed, all injected states are detached. This frees the underlying GPU memory so the next step has as much memory as possible.
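The lifecycle above can be sketched in plain Python. The class names and config keys here are hypothetical stand-ins for the real implementations in `injectors.py`:

```python
# Minimal sketch of the injector pattern: each injector is built from a config
# dict, and the dict returned by forward() is merged into the shared state.

class Injector:
    def __init__(self, opt):
        self.opt = opt                 # config dict, straight from the file
        self.out_key = opt['out']      # where the result is bound in the state

    def forward(self, state):
        raise NotImplementedError

class ConstantInjector(Injector):
    """Puts a flat value into the state (cf. torch.zeros / torch.ones)."""
    def forward(self, state):
        return {self.out_key: self.opt['value']}

class ScaleInjector(Injector):
    """Stand-in for a generator pass: reads one state key, writes another."""
    def forward(self, state):
        return {self.out_key: state[self.opt['in']] * self.opt['scale']}

def run_injectors(injectors, state):
    for inj in injectors:
        state.update(inj.forward(state))  # each result mutates the state
    return state

# Usage: start with an input state, run the configured injectors in order.
state = run_injectors(
    [ConstantInjector({'out': 'zeros', 'value': 0.0}),
     ScaleInjector({'in': 'lq', 'out': 'gen', 'scale': 2.0})],
    {'lq': 3.0})
# state now holds keys 'lq', 'zeros' and 'gen'
```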
### Example injectors

- Performing a forward pass with a generator and storing the result in the program state.
- Loading flat values into the state (e.g. `torch.zeros`, `torch.ones`, `torch.rand`).
- Adding noise to a state variable.
- Performing differentiable augmentations on an image tensor.
See a full list of currently implemented injectors (and templates for how to add your own) in `injectors.py`.
### Rules of thumb
Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think long and hard about whether it would be better done in your model architecture code. It is technically feasible to implement entire models with injectors, but that would result in unreadable configs. Strike a balance between configurability and maintainability.
## Losses
Losses simply convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The output of each loss is multiplied by its weight and all the weighted losses are summed together before performing a backward pass.

Some models directly output a loss. This is fine - you can use the `direct` loss to accommodate this.

Losses are defined in `losses.py`.
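The weighting scheme amounts to a weighted sum over loss outputs. A minimal sketch, with hypothetical names standing in for the real classes in `losses.py`:

```python
# Sketch of weighted loss accumulation: every loss reads the trainer state,
# is scaled by its configured weight, and the weighted sum is what gets
# backpropagated.

def total_loss(weighted_losses, state):
    # weighted_losses: list of (loss_fn, weight) pairs, where each loss_fn
    # reads whatever state keys it needs and returns a scalar.
    return sum(weight * loss_fn(state) for loss_fn, weight in weighted_losses)
```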
## Evaluators
As DLAS was extended past super-resolution (SR), it became necessary to support more complicated evaluation behaviors, e.g. FID or the SRFlow Gaussian distance. To enable this, the concept of the `Evaluator` was added. Classes in the `eval` folder contain various evaluator implementations. These can be fed directly into the `eval` section of your config file and will be executed alongside (or instead of) your validation set.
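The evaluator contract can be thought of as a class that is configured from the config file and produces named metrics. The interface below is a hypothetical simplification of the classes in the `eval` folder:

```python
# Hypothetical sketch of an evaluator: built from its section of the config,
# it runs some computation and returns a dict of named metrics.

class Evaluator:
    def __init__(self, opt):
        self.opt = opt

    def perform_eval(self):
        raise NotImplementedError

class MeanEvaluator(Evaluator):
    """Toy evaluator: reports the mean of a list of configured values."""
    def perform_eval(self):
        vals = self.opt['values']
        return {'mean': sum(vals) / len(vals)}
```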