remove redundant loss, tweak readme
parent 04a80d6b55
commit 71e373064f

README.md
@@ -98,6 +98,7 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
 
 The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
 
+### Plotting Metrics
 
 Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
 
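For context on the `lr` hot-adjustment above: changing the learning rate on the fly ultimately means rewriting the `lr` entry of each optimizer parameter group. The snippet below is a minimal stand-alone sketch of that mechanism in plain PyTorch; `set_learning_rate` and the toy model are illustrative, not the trainer's actual code.

```python
import torch

# Minimal sketch of what an on-the-fly `lr 1.0e-3` command amounts to:
# overwrite the learning rate in every optimizer param group.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float) -> None:
    for group in optimizer.param_groups:
        group["lr"] = lr

set_learning_rate(optimizer, 1.0e-3)    # the effect of typing `lr 1.0e-3`
print(optimizer.param_groups[0]["lr"])  # 0.001
```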
@@ -106,6 +107,8 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
+### Notices
+
 If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
 
 #### Training Under Windows
 
 As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
 
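As a rough illustration of the `float16` note above: with the `local` backend, `amp` corresponds to PyTorch's automatic mixed precision, i.e. an autocast region plus a gradient scaler, rather than deepspeed's fp16 engine. The training step below is a hedged sketch with a made-up model and loss (and it needs a CUDA device); it is not the repo's trainer loop.

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)
scaler = torch.cuda.amp.GradScaler()

def train_step(batch: torch.Tensor, target: torch.Tensor) -> float:
    optimizer.zero_grad()
    # run the forward pass in float16 under autocast
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(batch), target)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

print(train_step(torch.randn(4, 8).cuda(), torch.randn(4, 8).cuda()))
```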
@@ -124,17 +127,17 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
 As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
 
-* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
-* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
-  - Its implementation for MoE can also be utilized.
-* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet/) with a HuggingFace-compatible RetNet model
-  - inferencing cost is about 0.5x, and MoE is not implemented.
 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
 * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
   - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
+* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
+* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
+  - Its implementation for MoE can also be utilized.
+* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
+  - inferencing cost is about 0.5x, and MoE is not implemented.
 
 If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
 It's recommended to use `llama` with `xformers`-based attention, as the savings are huge in comparison to even `retnet`-backed models.
 
 ## Export
 
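To make the `llama` bullet above concrete: that backend is essentially a Hugging Face `LlamaModel` decoder stack driven with pre-computed embeddings rather than token ids. Everything below (hidden size, depth, sequence length) is made up for illustration and is not the repo's actual model wiring.

```python
import torch
from transformers import LlamaConfig, LlamaModel

# Tiny decoder stack; sizes are illustrative only.
config = LlamaConfig(
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=12,
    num_attention_heads=16,
    vocab_size=1,  # unused here, since embeddings are fed in directly
)
lm = LlamaModel(config)

# A VALL-E-style model hands the LM a sequence of summed text/prompt/audio embeddings.
inputs_embeds = torch.randn(1, 128, config.hidden_size)
hidden = lm(inputs_embeds=inputs_embeds).last_hidden_state
print(hidden.shape)  # torch.Size([1, 128, 1024])
```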
@@ -201,7 +201,7 @@ class Engine():
 
 	def _get_grad_norm(self):
 		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
-		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
+		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
 
 	def get_lr(self):
 		lrs = []
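The hunk above changes the fallback when no parameter has a gradient yet: `None` instead of `0`, so a not-yet-computed norm can be told apart from a genuinely zero one. The standalone sketch below mirrors the same computation on a toy module; it is an illustration, not the `Engine` class itself.

```python
import torch

def global_grad_norm(module: torch.nn.Module) -> float | None:
    # Flatten every available gradient and take the 2-norm of the concatenation;
    # return None (rather than 0) when no gradients exist yet.
    grads = [p.grad.detach().flatten() for p in module.parameters() if p.grad is not None]
    return torch.cat(grads).norm().item() if len(grads) else None

model = torch.nn.Linear(4, 4)
print(global_grad_norm(model))             # None -- no backward pass has run yet
model(torch.randn(2, 4)).sum().backward()
print(global_grad_norm(model))             # a positive float
```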
@@ -478,14 +478,13 @@ class Engines(dict[str, Engine]):
 			flatten_dict(
 				{
 					name.split("-")[0]: dict(
-						loss=loss.item(),
+						**engine_stats,
 						lr=engine.get_lr()[0],
-						grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
+						grad_norm=engine.get_global_grad_norm(),
 						elapsed_time=elapsed_time,
 						engine_step=engine.global_step,
 						samples_processed=engine.global_samples,
 						tokens_processed=engine.tokens_processed,
 					)
 				}
 			),
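On the "remove redundant loss" half of the commit: if `engine_stats` already carries a `loss` entry, passing both an explicit `loss=...` and `**engine_stats` to `dict()` raises a duplicate-keyword `TypeError`, so dropping the explicit field and unpacking `engine_stats` first sidesteps both the redundancy and the collision. A minimal reproduction, assuming a stats dict shape that is only illustrative:

```python
# Assumed, illustrative shape of the per-engine stats dict.
engine_stats = {"loss": 0.412, "acc": 0.87}

# Old-style construction: explicit loss plus **engine_stats that also carries "loss".
try:
    stats = dict(loss=0.412, lr=1.0e-4, **engine_stats)
except TypeError as e:
    print(e)  # dict() got multiple values for keyword argument 'loss'

# New-style construction: let engine_stats provide loss, keep explicit fields after it.
stats = dict(**engine_stats, lr=1.0e-4)
print(stats)  # {'loss': 0.412, 'acc': 0.87, 'lr': 0.0001}
```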