DLAS Trainer
This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.
ExtensibleTrainer has three main components: steps, injectors and losses.
Steps
A step is loosely associated with all of the computation needed to perform a PyTorch optimizer's step() function. That is:
- Compute a forward pass.
- Compute a loss.
- Compute a backward pass and gather gradients.
A step also covers all the logging and other 'homework' associated with the above.
Since DLAS often trains GANs, it needs to support optimizing multiple networks concurrently. This is why the notion of a step is separated from the trainer: a single training iteration can involve more than one optimizer step.
Most of the logic for how a step operates can be found in steps.py.
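The relationship between a training iteration and its steps can be sketched in plain Python. This is a minimal illustration under stated assumptions, not ExtensibleTrainer's API: `SketchStep`, `SketchOptimizer` and `run_iteration` are hypothetical names, plain floats stand in for tensors, and the backward pass is only marked with a comment because there is no autograd here.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

State = Dict[str, float]


@dataclass
class SketchOptimizer:
    """Hypothetical stand-in for a PyTorch optimizer; it only counts step() calls."""
    name: str
    steps_taken: int = 0

    def step(self) -> None:
        self.steps_taken += 1


@dataclass
class SketchStep:
    """Hypothetical step: one forward pass, one loss and one optimizer step."""
    optimizer: SketchOptimizer
    forward: Callable[[State], None]          # mutates the shared state
    compute_loss: Callable[[State], float]    # reads the state, returns a loss
    loss_log: List[float] = field(default_factory=list)

    def do_step(self, state: State) -> float:
        self.forward(state)                   # 1. forward pass
        loss = self.compute_loss(state)       # 2. loss
        # 3. In PyTorch, loss.backward() would gather gradients here.
        self.optimizer.step()
        self.loss_log.append(loss)            # logging / 'homework'
        return loss


def run_iteration(steps: List[SketchStep], state: State) -> List[float]:
    """One training iteration: every step runs in order, each with its own optimizer."""
    return [s.do_step(state) for s in steps]


# A GAN-style iteration: a generator step followed by a discriminator step.
gen_opt, disc_opt = SketchOptimizer("generator"), SketchOptimizer("discriminator")


def gen_forward(state: State) -> None:
    state["gen"] = state["lq"] * 2.0          # toy "generator"


def disc_forward(state: State) -> None:
    state["disc_out"] = state["gen"] - state["hq"]  # toy "discriminator"


steps = [
    SketchStep(gen_opt, gen_forward, lambda s: abs(s["gen"] - s["hq"])),
    SketchStep(disc_opt, disc_forward, lambda s: s["disc_out"] ** 2),
]
losses = run_iteration(steps, {"lq": 1.0, "hq": 3.0})
```

Both optimizers are stepped exactly once per iteration, which is the reason the step concept lives outside the trainer loop itself.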
Injectors
Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of ExtensibleTrainer as a state machine, injectors are the ways to mutate that state.
There are no hard rules on what an injector can do, but generally it operates as follows:
- On startup, it is initialized with a configuration dict fed directly from the config file.
- During the forward pass of a step, each injector's forward() method is invoked in sequence, and each one receives the current 'state' of the trainer.
- The injector performs some computation and stores the result into the state. How these results are bound to the state is generally defined within the configuration file. For example "inject the output of this generator into key 'gen'".
- Losses (discussed next) feed off of the state generated by the injectors.
- After the step is completed, all injected states are detached. This frees the underlying GPU memory so the next step has as much memory as possible.
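The injector lifecycle described above can be sketched in plain Python. This is a hypothetical model, not the classes in injectors.py: the `in`/`out`/`type` config keys, `SketchInjector`, `build_injectors` and `run_injectors` are assumptions for illustration, lists of floats stand in for tensors, and the final detach phase is not modeled.

```python
from typing import Any, Dict, List

State = Dict[str, Any]


class SketchInjector:
    """Hypothetical base injector, configured directly from a config dict."""

    def __init__(self, opt: Dict[str, Any]):
        self.input = opt.get("in")   # state key to read (may be absent)
        self.output = opt["out"]     # state key to write, chosen in the config

    def forward(self, state: State) -> State:
        raise NotImplementedError


class ConstantInjector(SketchInjector):
    """Injects a flat value shaped like the input (cf. torch.zeros/torch.ones)."""

    def __init__(self, opt: Dict[str, Any]):
        super().__init__(opt)
        self.value = opt["value"]

    def forward(self, state: State) -> State:
        return {self.output: [self.value] * len(state[self.input])}


class ScaleInjector(SketchInjector):
    """Stand-in for a generator forward pass: multiplies every input element."""

    def __init__(self, opt: Dict[str, Any]):
        super().__init__(opt)
        self.factor = opt["factor"]

    def forward(self, state: State) -> State:
        return {self.output: [self.factor * v for v in state[self.input]]}


INJECTOR_TYPES = {"constant": ConstantInjector, "scale": ScaleInjector}


def build_injectors(configs: List[Dict[str, Any]]) -> List[SketchInjector]:
    """Instantiate each injector from its (hypothetical) config entry."""
    return [INJECTOR_TYPES[c["type"]](c) for c in configs]


def run_injectors(injectors: List[SketchInjector], state: State) -> State:
    """Invoke injectors in order; each sees the state left by the previous ones."""
    for injector in injectors:
        state.update(injector.forward(state))  # bind results to the configured key
    return state


config = [
    {"type": "scale", "in": "lq", "out": "gen", "factor": 2.0},
    {"type": "constant", "in": "gen", "out": "zeros", "value": 0.0},
]
state = run_injectors(build_injectors(config), {"lq": [1.0, 2.0]})
```

Because the second injector reads "gen", it only works if the first one has already run, which is why ordering in the config matters.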
Example injectors:
- Performing a forward pass with a generator and storing the result in the program state.
- Injecting constant values into the state (e.g. torch.zeros, torch.ones, torch.rand)
- Adding noise to a state variable
- Performing differentiable augmentations to an image tensor
See a full list of currently implemented injectors (and templates for how to add your own) in injectors.py.
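As an illustration of the "adding noise to a state variable" case, here is a self-contained sketch. `NoiseInjector` and its `in`/`out`/`std`/`seed` options are hypothetical and not taken from injectors.py; a seeded `random.Random` replaces `torch.randn` so the example runs without PyTorch.

```python
import random
from typing import Any, Dict, List


class NoiseInjector:
    """Hypothetical injector that adds Gaussian noise to a state variable."""

    def __init__(self, opt: Dict[str, Any]):
        self.input = opt["in"]
        self.output = opt["out"]
        self.std = opt.get("std", 1.0)
        self.rng = random.Random(opt.get("seed"))  # seeded for reproducibility

    def forward(self, state: Dict[str, Any]) -> Dict[str, List[float]]:
        noisy = [v + self.rng.gauss(0.0, self.std) for v in state[self.input]]
        return {self.output: noisy}  # the input key is left untouched


state = {"gen": [1.0, 2.0, 3.0]}
injector = NoiseInjector({"in": "gen", "out": "gen_noisy", "std": 0.1, "seed": 0})
state.update(injector.forward(state))
```

Writing to a new key rather than overwriting the input keeps the clean value available to losses later in the step.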
Rules of thumb
Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think carefully about whether it would be better done in your model architecture code. It is technically feasible to implement entire models with injectors, but that would result in unreadable configs. Strike a balance between configurability and maintainability.
Losses
Losses simply convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The output of the loss is multiplied by this weight, and all the weighted losses are summed together before performing a backward pass.
Some models directly output a loss. This is fine: you can use the "direct" loss to accommodate this.
Losses are defined in losses.py.
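The weighting rule can be sketched in a few lines of plain Python. The helpers `l1_loss`, `direct_loss` and `total_loss` are hypothetical and only mirror the behavior described above; they are not the classes in losses.py, and floats stand in for tensors, so no backward pass is performed.

```python
from typing import Any, Callable, Dict, List, Tuple

State = Dict[str, Any]
LossFn = Callable[[State], float]


def l1_loss(real_key: str, fake_key: str) -> LossFn:
    """Mean absolute error between two state entries."""
    def fn(state: State) -> float:
        real, fake = state[real_key], state[fake_key]
        return sum(abs(r - f) for r, f in zip(real, fake)) / len(real)
    return fn


def direct_loss(key: str) -> LossFn:
    """The model already wrote a loss into the state; pass it through unchanged."""
    return lambda state: state[key]


def total_loss(weighted: List[Tuple[float, LossFn]], state: State) -> float:
    """Multiply each loss by its weight and sum; backward() would run on this value."""
    return sum(weight * fn(state) for weight, fn in weighted)


state = {"hq": [1.0, 2.0], "gen": [1.5, 2.5], "model_loss": 0.25}
losses = [(2.0, l1_loss("hq", "gen")), (0.5, direct_loss("model_loss"))]
print(total_loss(losses, state))  # 1.125
```

Here the L1 term is 0.5 and the direct term is 0.25, so the total is 2.0 * 0.5 + 0.5 * 0.25 = 1.125.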
Evaluators
As DLAS was extended beyond super-resolution (SR), it became necessary to support more complicated evaluation behaviors, e.g. FID or SRFlow's Gaussian distance. To enable this, the concept of the Evaluator was added. Classes in the eval folder contain various evaluator implementations. These can be fed directly into the eval section of your config file and will be executed alongside (or instead of) your validation set.
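A minimal sketch of the evaluator idea: each evaluator runs the model over its own data and returns a dict of metrics that the trainer can log next to validation results. `SketchEvaluator`, `perform_eval` and `run_evaluators` are hypothetical names, and the mean squared error is a simple stand-in for metrics like FID.

```python
from typing import Callable, Dict, List, Sequence


class SketchEvaluator:
    """Hypothetical evaluator: runs a model over its own data and returns metrics."""

    def __init__(self, model: Callable[[float], float],
                 inputs: Sequence[float], targets: Sequence[float]):
        self.model = model
        self.inputs = inputs
        self.targets = targets

    def perform_eval(self) -> Dict[str, float]:
        preds = [self.model(x) for x in self.inputs]
        mse = sum((p - t) ** 2 for p, t in zip(preds, self.targets)) / len(preds)
        return {"mse": mse}


def run_evaluators(evaluators: List[SketchEvaluator]) -> Dict[str, float]:
    """Merge every evaluator's metrics; later evaluators overwrite duplicate keys."""
    metrics: Dict[str, float] = {}
    for ev in evaluators:
        metrics.update(ev.perform_eval())
    return metrics


metrics = run_evaluators([SketchEvaluator(lambda x: 2 * x, [1.0, 2.0], [2.0, 5.0])])
```

Keeping evaluation behind a single method makes it easy to plug arbitrary metrics in from the config without touching the validation loop.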