diff --git a/codes/trainer/README.md b/codes/trainer/README.md
new file mode 100644
index 00000000..96c372d0
--- /dev/null
+++ b/codes/trainer/README.md
@@ -0,0 +1,68 @@
# DLAS Trainer

This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.

ExtensibleTrainer has three main components: **steps**, **injectors** and **losses**.

## Steps

A step is loosely associated with all of the computation needed to perform a PyTorch optimizer's `step()` function. That is:

1. Compute a forward pass.
1. Compute a loss.
1. Compute a backward pass and gather gradients.

It also covers all of the logging and other housekeeping associated with the above.

Since DLAS often trains GANs, it necessarily needs to support optimizing multiple networks concurrently. This is why the notion of a step is broken out of the trainer: each training step can correspond to more than one optimizer step.

Most of the logic for how a step operates can be found in `steps.py`. A sketch tying steps, injectors and losses together appears after the Losses section below.

## Injectors

Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of ExtensibleTrainer as a state machine, injectors are the way to mutate that state.

There are no hard rules about what an injector can do, but generally it operates like this:

1. On startup, it is initialized with a configuration dict fed directly from the config file.
1. During the forward pass of a step, each injector's `forward()` method is invoked sequentially, with each one receiving the current 'state' of the trainer.
1. The injector performs some computation and stores the result into the state. How these results are bound to the state is generally defined within the configuration file. For example: "inject the output of this generator into key 'gen'".
1. Losses (discussed next) feed off of the state generated by the injectors.
1. After the step is completed, all injected states are detached. This frees the underlying GPU memory so the next step has as much memory as possible.

### Example injectors:
- Performing a forward pass with a generator and storing the result in the program state.
- Injecting flat values into the state (e.g. `torch.zeros`, `torch.ones`, `torch.rand`).
- Adding noise to a state variable.
- Performing differentiable augmentations on an image tensor.

See a full list of currently implemented injectors (and templates for how to add your own) in `injectors.py`. A minimal injector sketch also appears after the Losses section below.

### Rules of thumb
Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think long and hard about whether it would be better done in your model architecture code. It is technically feasible to implement entire models with injectors, but that would result in unreadable configs. Strike a balance between configurability and maintainability.

## Losses

Losses simply convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The output of the loss is multiplied by this weight, and all the weighted losses are summed together before performing the backward pass.

Some models directly output a loss. This is fine - you can use the `direct` loss to accommodate this.

Losses are defined in `losses.py`.
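To make the injector contract concrete, here is a minimal sketch of a hypothetical noise injector. The class name, the expected `opt` keys (`in`, `out`, `scale`) and the exact `forward()` signature are illustrative assumptions rather than the real DLAS interface; `injectors.py` has the actual templates.

```python
import torch

class NoiseInjector:
    """Hypothetical injector that adds Gaussian noise to one state key.

    The class name, the expected `opt` keys and the forward() signature
    are illustrative assumptions, not the exact DLAS interface.
    """
    def __init__(self, opt):
        self.input_key = opt['in']          # state key to read from
        self.output_key = opt['out']        # state key to write to
        self.scale = opt.get('scale', 0.1)  # noise magnitude

    def forward(self, state):
        tensor = state[self.input_key]
        noisy = tensor + self.scale * torch.randn_like(tensor)
        # An injector returns new state entries; the trainer merges
        # them into the shared state for later injectors and losses.
        return {self.output_key: noisy}

# Toy usage: mutate the state the way a step's forward pass would.
state = {'lq': torch.rand(1, 3, 8, 8)}
inj = NoiseInjector({'in': 'lq', 'out': 'lq_noisy', 'scale': 0.05})
state.update(inj.forward(state))
```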
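Similarly, here is a hedged sketch of what a single step boils down to once injectors and weighted losses are wired together. The state keys (`lq`, `hq`, `gen`) and the loss table are invented for illustration; the real orchestration (multiple optimizers, logging, and so on) lives in `steps.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Simplified, assumed wiring: run injectors to build the state, turn
# the state into weighted losses, back-propagate the sum, then detach
# everything so the next step starts with free GPU memory.
net = nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

state = {'lq': torch.rand(8, 4), 'hq': torch.rand(8, 4)}
state['gen'] = net(state['lq'])  # injector: generator output into key 'gen'

# Each loss is a (weight, value) pair; a GAN step would add more entries.
losses = {'pix': (1.0, F.l1_loss(state['gen'], state['hq']))}
total = sum(weight * value for weight, value in losses.values())

optimizer.zero_grad()
total.backward()
optimizer.step()

# Detach injected tensors, mirroring the post-step cleanup described above.
state = {k: v.detach() for k, v in state.items()}
```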
## Evaluators

As DLAS was extended beyond super-resolution (SR), it became necessary to support more complicated evaluation behaviors, e.g. FID or SRFlow's Gaussian distance. To enable this, the concept of the `Evaluator` was added. Classes in the `eval` folder contain various evaluator implementations. These can be fed directly into the `eval` section of your config file and will be executed alongside (or instead of) your validation set.
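As a rough illustration of the idea, an evaluator just turns a model plus some data into a dict of logged metrics. The class below is hypothetical; its constructor arguments and `perform()` method are assumptions for illustration, not the actual base class from the `eval` folder.

```python
import torch

class ToyPSNREvaluator:
    """Hypothetical evaluator; the constructor signature and perform()
    method name are assumptions, not the actual base class."""

    def __init__(self, model, opt):
        self.model = model
        self.opt = opt

    def perform(self):
        # A real evaluator would iterate an eval dataset; this scores
        # one random pair just to show the shape of the contract.
        pred = torch.rand(1, 3, 8, 8)
        target = torch.rand(1, 3, 8, 8)
        mse = torch.mean((pred - target) ** 2)
        psnr = 10 * torch.log10(1.0 / mse)
        # Evaluators report a dict of scalar metrics for logging.
        return {'psnr': psnr.item()}

print(ToyPSNREvaluator(None, {}).perform())
```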