From 71e373064f1b473d6228b81122ab8e7f35e22470 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 May 2024 15:02:47 -0500
Subject: [PATCH] remove redundant loss, tweak readme

---
 README.md              | 15 +++++++++------
 vall_e/engines/base.py |  7 +++----
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 0063ddf..5f6e83e 100755
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
 
 The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
+
 ### Plotting Metrics
 
 Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
 
@@ -106,6 +107,8 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
 
 ### Notices
 
+If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
+
 #### Training Under Windows
 
 As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
@@ -124,17 +127,17 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
 
 As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
 
-* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
-* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
-  - Its implementation for MoE can also be utilized.
-* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet/) with a HuggingFace-compatible RetNet model
-  - inferencing cost is about 0.5x, and MoE is not implemented.
 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
 * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
   - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
+* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
+* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
+  - Its implementation for MoE can also be utilized.
+* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
+  - inferencing cost is about 0.5x, and MoE is not implemented.
 
-If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
+It's recommended to use `llama` with `xformers`-based attention, as the savings are huge in comparison to even `retnet`-backed models.
 
 ## Export
 
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index eeca527..6d4ccb7 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -201,7 +201,7 @@ class Engine():
 	def _get_grad_norm(self):
 		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
-		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
+		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
 
 	def get_lr(self):
 		lrs = []
@@ -478,14 +478,13 @@ class Engines(dict[str, Engine]):
 					flatten_dict(
 						{
 							name.split("-")[0]: dict(
-								loss=loss.item(),
+								**engine_stats,
 								lr=engine.get_lr()[0],
-								grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
+								grad_norm=engine.get_global_grad_norm(),
 								elapsed_time=elapsed_time,
 								engine_step=engine.global_step,
 								samples_processed=engine.global_samples,
 								tokens_processed=engine.tokens_processed,
-								**engine_stats,
 							)
 						}
 					),
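
The `_get_grad_norm` tweak above swaps the no-gradient sentinel from `0` to `None`, which lets consumers of `get_global_grad_norm()` tell "nothing to measure yet" apart from a genuinely zero gradient norm. A minimal, self-contained sketch of the same computation, written as a free function rather than the repo's `Engine` method:

```python
import torch

def global_grad_norm(module: torch.nn.Module):
	"""Global L2 norm over all parameter gradients; None if no grads exist yet."""
	grads = [p.grad.detach().flatten() for p in module.parameters() if p.grad is not None]
	if not grads:
		return None  # was 0 before this patch, indistinguishable from a true zero norm
	return torch.cat(grads).norm().item()

# Before any backward pass there are no .grad tensors, so the result is None.
model = torch.nn.Linear(4, 2)
assert global_grad_norm(model) is None
model(torch.randn(3, 4)).sum().backward()
print(global_grad_norm(model))  # a float once gradients exist
```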
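
On the `Engines` stats change: a minimal sketch of why the explicit `loss` entry was redundant. The `flatten_dict` below is a simplified stand-in for the repo's helper, and `engine_stats` carrying the step's loss is an assumption inferred from the commit message, not verified against the training loop:

```python
# Illustrative only: a simplified flatten_dict, not the one in vall_e.
def flatten_dict(d: dict, parent: str = "") -> dict:
	"""Flatten {"ar": {"loss": 0.5}} into {"ar.loss": 0.5}."""
	out = {}
	for k, v in d.items():
		key = f"{parent}.{k}" if parent else k
		if isinstance(v, dict):
			out.update(flatten_dict(v, key))
		else:
			out[key] = v
	return out

# Assumed shape of the engine's per-step stats (hypothetical values).
engine_stats = {"loss": 0.5, "stats": {"acc": 0.9}}

# After the patch, engine_stats is the single source of truth for the loss;
# the old explicit entry repeated a value the engine already reported.
# Note that dict(loss=..., **engine_stats) would even raise a TypeError if
# engine_stats itself contained a "loss" key, since dict() rejects duplicate
# keyword arguments.
entry = {
	"ar": dict(
		**engine_stats,
		lr=1.0e-4,        # engine.get_lr()[0]
		grad_norm=None,   # engine.get_global_grad_norm(), possibly None now
		engine_step=1234,
	)
}
print(flatten_dict(entry))
# {'ar.loss': 0.5, 'ar.stats.acc': 0.9, 'ar.lr': 0.0001, 'ar.grad_norm': None, 'ar.engine_step': 1234}
```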