DLAS Trainer

This directory contains the code for ExtensibleTrainer, which is a configuration-driven generator trainer for PyTorch.

ExtensibleTrainer has three main components: steps, injectors and losses.

Steps

A step loosely corresponds to all of the computation needed to perform a single call to a PyTorch optimizer's step() function. That is:

  1. Compute a forward pass.
  2. Compute a loss.
  3. Compute a backward pass and gather gradients.

A step also covers all of the logging and other 'homework' associated with the above.
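The three stages above can be illustrated with a small, dependency-free sketch. This is not DLAS code. A one-parameter linear model y = w * x stands in for a PyTorch network, the gradient is derived by hand instead of by autograd, and the names `SGD` and `training_step` are hypothetical.

```python
class SGD:
    """Hypothetical stand-in for a PyTorch optimizer (plain SGD on a dict of scalars)."""

    def __init__(self, params, lr):
        self.params = params  # shared dict: parameter name -> float value
        self.lr = lr
        self.grads = {name: 0.0 for name in params}

    def zero_grad(self):
        for name in self.grads:
            self.grads[name] = 0.0

    def step(self):
        # Apply the gathered gradients, like torch.optim.Optimizer.step().
        for name in self.params:
            self.params[name] -= self.lr * self.grads[name]


def training_step(params, opt, batch):
    """One 'step': forward pass, loss, backward pass, then the optimizer step."""
    opt.zero_grad()
    xs, ys = batch
    n = len(xs)
    # 1. Forward pass with the toy model y = w * x.
    preds = [params["w"] * x for x in xs]
    # 2. Loss: mean squared error.
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # 3. Backward pass, done by hand: dL/dw = 2/n * sum((p - y) * x).
    opt.grads["w"] += 2.0 / n * sum((p - y) * x for p, y, x in zip(preds, ys, xs))
    opt.step()
    # 'Homework': return values for logging.
    return {"loss": loss}
```

Repeatedly calling `training_step(params, SGD(params, lr=0.1), ([1.0, 2.0], [2.0, 4.0]))` starting from `params = {"w": 0.0}` drives `w` toward 2.0.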

Since DLAS often trains GANs, it needs to support optimizing multiple networks concurrently. This is why the notion of a step is separated from the trainer itself: a single training iteration can correspond to more than one optimizer step.
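The following is a minimal sketch of how one training iteration might run several steps in order, each owning its own optimizer (for example, a discriminator step followed by a generator step in a GAN). The classes `CountingOptimizer` and `Step` and the function `train_iteration` are hypothetical, and the real forward/loss/backward work is elided.

```python
class CountingOptimizer:
    """Hypothetical optimizer stand-in that only records how often step() ran."""

    def __init__(self, name):
        self.name = name
        self.num_steps = 0

    def step(self):
        self.num_steps += 1


class Step:
    """Hypothetical training step bound to one network's optimizer."""

    def __init__(self, name, optimizer):
        self.name = name
        self.optimizer = optimizer

    def run(self, state):
        # The forward pass, loss and backward pass for this network would go here.
        state.setdefault("log", []).append(self.name)
        self.optimizer.step()


def train_iteration(steps, state):
    # Every step runs once per iteration, in configuration order.
    for step in steps:
        step.run(state)
    return state
```

With `steps = [Step("discriminator", d_opt), Step("generator", g_opt)]`, one call to `train_iteration` performs two optimizer steps, one per network.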

Most of the logic for how a step operates can be found in steps.py.

Injectors

Injectors are a way to drive a network's forward pass entirely from a configuration file. If you think of ExtensibleTrainer as a state machine, injectors are the ways to mutate that state.

There are no hard rules on what an injector can do, but it generally operates as follows:

  1. On startup, it is initialized with a configuration dict fed directly from the config file.
  2. During the forward pass of a step, the injector.forward() methods are invoked sequentially, each receiving the current 'state' of the trainer.
  3. The injector performs some computation and stores the result into the state. How these results are bound to the state is generally defined within the configuration file. For example "inject the output of this generator into key 'gen'".
  4. Losses (discussed next) feed off of the state generated by the injectors.
  5. After the step is completed, all injected states are detached. This frees the underlying GPU memory so the next step has as much memory as possible.
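The life cycle above can be sketched in plain Python. This is an illustration under assumptions, not the DLAS implementation. The config keys `"in"`, `"out"` and `"generator"` and the names `Injector`, `GeneratorInjector` and `run_injectors` are hypothetical. The detaching in stage 5 is PyTorch-specific and is not shown.

```python
class Injector:
    """Hypothetical base class: configured from a dict taken from the config file."""

    def __init__(self, opt):
        self.input = opt.get("in")   # state key to read (optional for some injectors)
        self.output = opt["out"]     # state key to write

    def forward(self, state):
        raise NotImplementedError


class GeneratorInjector(Injector):
    """Runs a named 'network' on state[in] and stores the result under state[out]."""

    def __init__(self, opt, networks):
        super().__init__(opt)
        self.net = networks[opt["generator"]]

    def forward(self, state):
        return {self.output: self.net(state[self.input])}


def run_injectors(injectors, state):
    # Injectors run sequentially. Each one sees the state produced so far,
    # and its outputs are merged back into the state under the configured keys.
    for injector in injectors:
        state.update(injector.forward(state))
    return state
```

Because each injector reads from the shared state, a later injector can consume an earlier one's output (for example, `"in": "gen_out"`). This is how a chain of forward passes can be expressed purely in configuration.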

Example injectors:

  • Performing a forward pass with a generator and storing the result in the program state
  • Injecting simple tensors into the state (e.g. torch.zeros, torch.ones, torch.rand)
  • Adding noise to a state variable
  • Performing differentiable augmentations to an image tensor
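Two of the examples above can be sketched as follows. The sketch uses plain Python lists instead of tensors so that it runs without PyTorch. `ConstantInjector` and `NoiseInjector`, along with their `"value"`, `"size"`, `"scale"` and `"seed"` options, are hypothetical names chosen for illustration.

```python
import random


class ConstantInjector:
    """Hypothetical: injects a flat list of a configured value (like torch.zeros/ones)."""

    def __init__(self, opt):
        self.value = opt["value"]
        self.size = opt["size"]
        self.output = opt["out"]

    def forward(self, state):
        return {self.output: [self.value] * self.size}


class NoiseInjector:
    """Hypothetical: adds zero-mean Gaussian noise with a configured scale to state[in]."""

    def __init__(self, opt):
        self.input = opt["in"]
        self.output = opt["out"]
        self.scale = opt.get("scale", 1.0)
        self.rng = random.Random(opt.get("seed"))  # seeded for reproducibility

    def forward(self, state):
        noisy = [v + self.rng.gauss(0.0, self.scale) for v in state[self.input]]
        # Write to a new key; the input entry in the state is left untouched.
        return {self.output: noisy}
```

Writing noise to a new key instead of overwriting the input lets losses compare the clean and noisy versions side by side.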

See a full list of currently implemented injectors (and templates for how to add your own) in injectors.py.

Rules of thumb

Simpler configuration files are generally better. If you need to mutate the trainer state for your model, think carefully about whether it would be better done in your model architecture code. It is technically feasible to implement entire models with injectors, but doing so would result in unreadable configs. Strike a balance between configurability and maintainability.

Losses

Losses simply convert the current trainer state into a differentiable loss. Each loss must have a "weight" assigned to it. The output of each loss is multiplied by its weight, and all the weighted losses are summed together before performing a backward pass.
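The weighting rule above amounts to total = sum(weight_i * loss_i). Here is a minimal sketch of it. Plain lists stand in for tensors, and the config layout (`"fn"`, `"args"`, `"weight"`) and the functions `l1_loss`, `l2_loss` and `total_loss` are assumptions made for illustration.

```python
def l1_loss(state, fake_key, real_key):
    """Mean absolute error between two state entries."""
    fake, real = state[fake_key], state[real_key]
    return sum(abs(f - r) for f, r in zip(fake, real)) / len(fake)


def l2_loss(state, fake_key, real_key):
    """Mean squared error between two state entries."""
    fake, real = state[fake_key], state[real_key]
    return sum((f - r) ** 2 for f, r in zip(fake, real)) / len(fake)


def total_loss(loss_configs, state):
    # Each configured loss reads from the state and is scaled by its weight.
    # The weighted terms are summed into the single value that would be backpropagated.
    total = 0.0
    log = {}
    for name, cfg in loss_configs.items():
        weighted = cfg["weight"] * cfg["fn"](state, **cfg["args"])
        log[name] = weighted
        total += weighted
    return total, log
```

Logging each weighted term separately, as `log` does here, makes it easy to see which loss dominates the total.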

Some models directly output a loss. This is fine: you can use the direct loss to accommodate this.
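Here is a rough sketch of the idea behind a direct loss: the model computes its own loss, an injector stores it in the state, and the loss simply reads it back out. `model_with_builtin_loss` and `DirectLoss`, including the `"key"` option, are hypothetical.

```python
def model_with_builtin_loss(x, target):
    """Hypothetical model that returns its prediction together with its own loss."""
    pred = 2.0 * x
    return pred, (pred - target) ** 2


class DirectLoss:
    """Hypothetical: passes a precomputed loss from the state through, scaled by its weight."""

    def __init__(self, opt):
        self.key = opt["key"]
        self.weight = opt["weight"]

    def __call__(self, state):
        # No further computation; the model already produced the loss value.
        return self.weight * state[self.key]
```

For example, `model_with_builtin_loss(1.0, 3.0)` returns `(2.0, 1.0)`. If the `1.0` is stored under `"model_loss"`, then `DirectLoss({"key": "model_loss", "weight": 0.5})` contributes `0.5` to the total.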

Losses are defined in losses.py.

Evaluators

As DLAS was extended beyond super-resolution (SR), it became necessary to support more complicated evaluation behaviors, such as FID or SRFlow's Gaussian distance. To enable this, the concept of the Evaluator was added. Classes in the eval folder contain various evaluator implementations. These can be referenced directly in the eval section of your config file and will be executed alongside (or instead of) your validation set.
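Below is a rough sketch of what an evaluator interface could look like. The names `Evaluator`, `perform_eval`, `MeanAbsErrorEvaluator` and `run_evaluators`, and the `"eval_frequency"` option, are assumptions for illustration, not taken from the eval folder. A simple mean absolute error stands in for a real metric such as FID.

```python
class Evaluator:
    """Hypothetical base class: wraps a model and decides when it should be evaluated."""

    def __init__(self, model, opt):
        self.model = model
        self.frequency = opt.get("eval_frequency", 1000)

    def should_run(self, step):
        return step % self.frequency == 0

    def perform_eval(self):
        raise NotImplementedError


class MeanAbsErrorEvaluator(Evaluator):
    """Stand-in metric: mean absolute error of the model on a fixed dataset."""

    def __init__(self, model, opt, dataset):
        super().__init__(model, opt)
        self.dataset = dataset  # list of (input, target) pairs

    def perform_eval(self):
        errors = [abs(self.model(x) - y) for x, y in self.dataset]
        return {"mae": sum(errors) / len(errors)}


def run_evaluators(evaluators, step):
    # Collect metrics from every evaluator that is due at this training step.
    metrics = {}
    for evaluator in evaluators:
        if evaluator.should_run(step):
            metrics.update(evaluator.perform_eval())
    return metrics
```

Returning a dict of named metrics lets several evaluators report into the same log alongside the validation results.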