Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares)

mrq 2024-02-29 20:29:17 -06:00
parent 3da1518ace
commit 35d78a2bb0
5 changed files with 57 additions and 16 deletions

View File

@@ -6,9 +6,9 @@
 An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
-[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
+[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
-> **Note** This README is still quite a disorganized mess.
+> **Note** Development on this is very sporadic. Gomen.
 ## Requirements
@@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
 - Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
 - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
-  + additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
+  + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
 ## Install
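A minimal sketch of the `PHONEMIZER_ESPEAK_LIBRARY` workaround from the note above; the DLL path is only a typical espeak-ng install location on Windows, not a value taken from this repo:

```python
import os

# Point phonemizer at the espeak-ng DLL before the backend is loaded.
# Hypothetical path: adjust to wherever espeak-ng is actually installed.
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"

from phonemizer import phonemize

print(phonemize("hello world", language="en-us", backend="espeak"))
```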
@@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 ## Try Me
-### Online
-A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
-### Local
 To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
 Each model file has a barebones trainer and inference routine.
@@ -52,7 +46,7 @@ Training is very dependent on:
 * the quality of your dataset.
 * how much data you have.
 * the bandwidth you quantized your audio to.
-* the underlying model architecture used
+* the underlying model architecture used.
 ### Pre-Processed Dataset
@@ -105,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train
 #### Training on Low-VRAM Cards
-During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16.
+During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation).
 VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
 Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).
+#### Backend Architectures
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
+* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
+* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
+  - Its implementation for MoE can also be utilized.
+* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
+* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
+* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
+  - Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
 ## Export
 To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
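On the BitsAndBytes note in the low-VRAM section above: a minimal sketch of the kind of optimizer swap it refers to, using bitsandbytes' 8-bit AdamW on a toy module (a stand-in, not this repo's AR/NAR; assumes a CUDA device and the `bitsandbytes` package):

```python
import torch
import bitsandbytes as bnb

# Toy stand-in for a model; the point is only the optimizer class.
model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit optimizer state is where most of the VRAM savings come from.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1.0e-4)

for _ in range(10):
    out = model(torch.randn(8, 1024, device="cuda"))
    loss = out.pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```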

View File

@@ -532,6 +532,8 @@ class BitsAndBytes:
     linear: bool = True
     embedding: bool = True
+    bitnet: bool = False
 @dataclass()
 class Config(_Config):
     device: str = "cuda"

View File

@@ -362,8 +362,8 @@ def example_usage():
     model = AR_NAR(**kwargs).to(device)
     steps = 500
-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+    optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)
     torch.save( {

View File

@@ -40,6 +40,19 @@ except Exception as e:
     print("Error importing `mistral` arch:", e)
     pass
+try:
+    from bitnet import BitNetTransformer
+    def NoEmbedding_BitNetTransformer_Forward(self, x):
+        x = self.transformer(x)
+        return self.to_logits[0](x)
+    BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
+except Exception as e:
+    print("Error importing `bitnet` arch:", e)
+    pass
 try:
     from transformers import MixtralModel, MixtralConfig
     from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
@@ -325,7 +338,7 @@ class Base(nn.Module):
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-        elif self.arch_type == "mistral":
+        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
                     vocab_size=n_resp_tokens,
@@ -425,6 +438,16 @@
             ))
             self.model = RetNetDecoder(RetNetConfig(**kwargs))
+        elif self.arch_type == "bitnet":
+            self.model = BitNetTransformer(
+                num_tokens=n_resp_tokens,
+                dim=d_model,
+                depth=n_layers,
+                heads=n_heads,
+                ff_mult=4,
+            )
+        else:
+            raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -486,7 +509,7 @@ class Base(nn.Module):
             # grab last token(s)
             x = x[:, -1, :].unsqueeze(1)
         # HF transformer derived model
-        elif self.arch_type == "llama" or self.arch_type == "mistral":
+        elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
             kwargs = dict(
                 #attention_mask=m,
                 inputs_embeds=x,
@@ -521,6 +544,8 @@
             x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
             if _ is not None and "l_aux" in _ and self.n_experts > 1:
                 aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
+        elif self.arch_type == "bitnet":
+            x = self.model(x)
         # output projection layer with masking
         x = self.classifier(x) * m
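For context, the three `bitnet` hunks above amount to: patch `BitNetTransformer.forward` so it accepts pre-computed embeddings (skipping its own token embedding and stopping after the first stage of `to_logits`, which evidently preserves the `d_model` width since the repo's own `classifier` projects it afterwards), construct it with the model's dimensions, and project with the existing `classifier`. A hedged, self-contained sketch of that path, assuming the linked `bitnet` package is installed and behaves as the patch expects; all sizes below are arbitrary:

```python
import torch
from torch import nn
from bitnet import BitNetTransformer  # https://github.com/kyegomez/BitNet

# Re-apply the forward override from the diff so this sketch runs standalone:
# inputs are already-computed embeddings, and only the first stage of
# `to_logits` is applied, leaving the projection to the caller's classifier.
def NoEmbedding_BitNetTransformer_Forward(self, x):
    x = self.transformer(x)
    return self.to_logits[0](x)

BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward

d_model, n_resp_tokens = 256, 1025   # arbitrary sizes for the sketch
model = BitNetTransformer(
    num_tokens=n_resp_tokens,  # same arguments as the `arch_type == "bitnet"` branch
    dim=d_model,
    depth=4,
    heads=4,
    ff_mult=4,
)
classifier = nn.Linear(d_model, n_resp_tokens)  # stand-in for Base.classifier

x = torch.randn(1, 32, d_model)   # pre-computed input embeddings
logits = classifier(model(x))     # expected shape: (1, 32, n_resp_tokens)
```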

View File

@@ -7,10 +7,18 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
+# https://github.com/kyegomez/BitNet
+if cfg.bitsandbytes.bitnet:
+    from bitnet import BitLinear
 if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb
     if cfg.bitsandbytes.linear:
+        if cfg.bitsandbytes.bitnet:
+            Linear = BitLinear
+        else:
             Linear = bnb.nn.Linear8bitLt
     if cfg.bitsandbytes.embedding:
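To show how this shim is meant to be consumed, a self-contained sketch of the same aliasing pattern with a stub standing in for the repo's `cfg` (the nested flag checks are collapsed with `and` for brevity; with the defaults below nothing beyond torch is used, so it runs without `bitsandbytes` or `bitnet` installed):

```python
import torch
from dataclasses import dataclass

@dataclass
class BitsAndBytesStub:
    # Mirrors the cfg.bitsandbytes flags used above; defaults keep plain torch.
    enabled: bool = False
    linear: bool = True
    bitnet: bool = False

bnb_cfg = BitsAndBytesStub()

Linear = torch.nn.Linear
if bnb_cfg.enabled and bnb_cfg.linear:
    if bnb_cfg.bitnet:
        from bitnet import BitLinear        # https://github.com/kyegomez/BitNet
        Linear = BitLinear
    else:
        import bitsandbytes as bnb
        Linear = bnb.nn.Linear8bitLt

# Model code builds layers through the alias, so flipping the flags swaps the
# underlying implementation without touching the call sites.
proj = Linear(256, 1024)
print(type(proj))
```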