Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares)

mrq 2024-02-29 20:29:17 -06:00
parent 3da1518ace
commit 35d78a2bb0
5 changed files with 57 additions and 16 deletions

View File

@@ -6,9 +6,9 @@
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
> **Note** This README is still quite a disorganized mess.
> **Note** Development on this is very sporadic. Gomen.
## Requirements
@@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
- Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
- additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
+ additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
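As a rough sketch (the DLL path below is only an example; point it at wherever `espeak-ng` actually installed), the variable just needs to be set before `phonemizer` is imported:

```python
# illustrative only: tell phonemizer where libespeak-ng.dll lives under Windows
import os
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"  # example path

from phonemizer import phonemize
print(phonemize("hello world", language="en-us", backend="espeak"))
```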
## Install
@@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
## Try Me
### Online
A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
### Local
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
Each model file has a barebones trainer and inference routine.
@@ -52,7 +46,7 @@ Training is very dependent on:
* the quality of your dataset.
* how much data you have.
* the bandwidth you quantized your audio to.
* the underlying model architecture used
* the underlying model architecture used.
### Pre-Processed Dataset
@@ -105,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train
#### Training on Low-VRAM Cards
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16.
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled, with both the AR and NAR at a batch size of 16, although I feel this is mostly snake oil. Better VRAM savings can be had by enabling BitsAndBytes through its respective flags (specifically its AdamW implementation); see the sketch below.
VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (albeit with the card as a secondary GPU).
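To be concrete about what the BitsAndBytes point amounts to, here is a minimal sketch (not the repo's actual wrapper; the model is a stand-in, and the repo gates this behind its own config flags):

```python
# illustrative only: swapping in bitsandbytes' 8-bit AdamW to cut optimizer-state VRAM
import torch
import torch.nn as nn

try:
    import bitsandbytes as bnb
    AdamW = bnb.optim.AdamW8bit   # optimizer state kept in 8 bits
except ImportError:
    AdamW = torch.optim.AdamW     # stock fallback

model = nn.Linear(1024, 1024).cuda()  # 8-bit optimizers expect CUDA parameters
optimizer = AdamW(model.parameters(), lr=1.0e-4)
```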
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed-forward layers.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized.
* `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
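As a rough sketch of what the `bitnet` backend boils down to (the hyperparameter values here are placeholders; the repo derives them from its own config, as the `base.py` diff further down shows):

```python
# illustrative only: kyegomez/BitNet's transformer as a drop-in backbone
import torch
from bitnet import BitNetTransformer

model = BitNetTransformer(
    num_tokens=1025,  # placeholder vocab size (the repo uses n_resp_tokens)
    dim=1024,         # placeholder d_model
    depth=12,         # placeholder n_layers
    heads=16,         # placeholder n_heads
    ff_mult=4,
)

tokens = torch.randint(0, 1025, (1, 64))
logits = model(tokens)  # the repo later patches forward() to take embeddings instead
print(logits.shape)
```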
## Export
To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.

View File

@@ -531,6 +531,8 @@ class BitsAndBytes:
    linear: bool = True
    embedding: bool = True
    bitnet: bool = False
@dataclass()
class Config(_Config):

View File

@@ -362,8 +362,8 @@ def example_usage():
    model = AR_NAR(**kwargs).to(device)
    steps = 500
    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
    #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
    optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
    engine = Engine(model=model, optimizer=optimizer)
    torch.save( {

View File

@@ -40,6 +40,19 @@ except Exception as e:
print("Error importing `mistral` arch:", e)
pass
try:
from bitnet import BitNetTransformer
def NoEmbedding_BitNetTransformer_Forward(self, x):
x = self.transformer(x)
return self.to_logits[0](x)
BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
except Exception as e:
print("Error importing `bitnet` arch:", e)
pass
try:
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
@@ -325,7 +338,7 @@ class Base(nn.Module):
                norm_type=self.norm_type,
                n_levels=self.n_resp_levels,
            ) for _ in range(n_layers) ])
        elif self.arch_type == "mistral":
        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
            if n_experts <= 1:
                self.model = MistralModel(MistralConfig(
                    vocab_size=n_resp_tokens,
@@ -425,6 +438,16 @@ class Base(nn.Module):
                ))
            self.model = RetNetDecoder(RetNetConfig(**kwargs))
        elif self.arch_type == "bitnet":
            self.model = BitNetTransformer(
                num_tokens=n_resp_tokens,
                dim=d_model,
                depth=n_layers,
                heads=n_heads,
                ff_mult=4,
            )
        else:
            raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

        self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -486,7 +509,7 @@ class Base(nn.Module):
            # grab last token(s)
            x = x[:, -1, :].unsqueeze(1)
        # HF transformer derived model
        elif self.arch_type == "llama" or self.arch_type == "mistral":
        elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
            kwargs = dict(
                #attention_mask=m,
                inputs_embeds=x,
@@ -521,6 +544,8 @@ class Base(nn.Module):
            x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
            if _ is not None and "l_aux" in _ and self.n_experts > 1:
                aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
        elif self.arch_type == "bitnet":
            x = self.model(x)

        # output projection layer with masking
        x = self.classifier(x) * m

View File

@@ -7,11 +7,19 @@ from ..config import cfg
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
# https://github.com/kyegomez/BitNet
if cfg.bitsandbytes.bitnet:
    from bitnet import BitLinear

if cfg.bitsandbytes.enabled:
    import bitsandbytes as bnb

    if cfg.bitsandbytes.linear:
        Linear = bnb.nn.Linear8bitLt
        if cfg.bitsandbytes.bitnet:
            Linear = BitLinear
        else:
            Linear = bnb.nn.Linear8bitLt

    if cfg.bitsandbytes.embedding:
        Embedding = bnb.nn.modules.Embedding
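The net effect: any layer the codebase builds through these `Linear`/`Embedding` aliases picks the backend up transparently. A minimal sketch of the intent (the flag combinations in the comments follow the logic above; the layer sizes are arbitrary):

```python
# illustrative only:
#   bitsandbytes.enabled + .linear + .bitnet -> Linear is BitNet's BitLinear
#   bitsandbytes.enabled + .linear           -> Linear is bnb.nn.Linear8bitLt
#   otherwise                                -> Linear stays torch.nn.Linear
proj = Linear(1024, 1024)
emb = Embedding(1025, 1024)
```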