Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares)
parent 3da1518ace
commit 35d78a2bb0

README.md | 28
@@ -6,9 +6,9 @@
 An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
 
-[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
+[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
 
 > **Note** This README is still quite a disorganized mess.
 
 > **Note** Development on this is very sporadic. Gomen.
 
 ## Requirements
@@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
   - Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
   - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
-    + additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
+    + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
 
 ## Install
@@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 ## Try Me
 
-### Online
-
-A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
-
-### Local
-
 To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
 
 Each model file has a barebones trainer and inference routine.
@@ -52,7 +46,7 @@ Training is very dependent on:
 * the quality of your dataset.
 * how much data you have.
 * the bandwidth you quantized your audio to.
-* the underlying model architecture used
+* the underlying model architecture used.
 
 ### Pre-Processed Dataset
@@ -105,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train
 #### Training on Low-VRAM Cards
 
-During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16.
+During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snake oil. Better VRAM savings can be had with BitsAndBytes and its respective flags (specifically its AdamW implementation).
 
 VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
 
 Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).
 
+#### Backend Architectures
+
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
+
+* `transformer`: a basic attention-based transformer implementation, with attention heads + feed-forwards.
+* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
+  - Its implementation for MoE can also be utilized.
+* `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
+* `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
+* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
+  - Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
+
 ## Export
 
 To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
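The "BitsAndBytes ... (specifically its AdamW implementation)" remark in the low-VRAM paragraph presumably points at bitsandbytes' 8-bit AdamW. A minimal sketch of the underlying library call, outside of this repo's own config and wrapper plumbing (the layer below is only a stand-in for the real model):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024)  # stand-in for the actual AR/NAR model
# 8-bit AdamW keeps optimizer state in 8 bits, which is where most of the VRAM saving comes from.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1.0e-4)

loss = model(torch.randn(4, 1024)).sum()
loss.backward()
optimizer.step()
```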
@@ -531,6 +531,8 @@ class BitsAndBytes:
     linear: bool = True
     embedding: bool = True
 
+    bitnet: bool = False
+
 @dataclass()
 class Config(_Config):
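For reference, a condensed sketch of the config block this hunk extends; `linear`, `embedding`, and the new `bitnet` field are taken from the diff, `enabled` from the utils hunk further below, and any default not shown in the diff is only illustrative:

```python
from dataclasses import dataclass

@dataclass()
class BitsAndBytes:
    enabled: bool = False     # master switch for bitsandbytes substitutions (default assumed)
    linear: bool = True       # replace torch.nn.Linear with bnb.nn.Linear8bitLt
    embedding: bool = True    # replace torch.nn.Embedding with bnb's Embedding
    bitnet: bool = False      # new in this commit: prefer BitNet's BitLinear over the 8-bit Linear
```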
@@ -362,8 +362,8 @@ def example_usage():
     model = AR_NAR(**kwargs).to(device)
     steps = 500
-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+    #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)
 
     torch.save( {
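Context for the swap above: Prodigy is a step-size-free optimizer that is conventionally run with `lr=1.0`, while AdamW needs an explicit learning rate, hence the different values on the two lines. A rough standalone equivalent, assuming the repo's `ml.Prodigy`/`ml.AdamW` wrappers forward to `prodigyopt` and `torch.optim` respectively:

```python
import torch
# from prodigyopt import Prodigy   # the now-commented-out path; Prodigy adapts its own step size

model = torch.nn.Linear(16, 16)    # stand-in for AR_NAR
# optimizer = Prodigy(model.parameters(), lr=1.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)
```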
@@ -40,6 +40,19 @@ except Exception as e:
     print("Error importing `mistral` arch:", e)
     pass
 
+try:
+    from bitnet import BitNetTransformer
+
+    def NoEmbedding_BitNetTransformer_Forward(self, x):
+        x = self.transformer(x)
+        return self.to_logits[0](x)
+
+    BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
+
+except Exception as e:
+    print("Error importing `bitnet` arch:", e)
+    pass
+
 try:
     from transformers import MixtralModel, MixtralConfig
     from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
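About the `NoEmbedding_BitNetTransformer_Forward` patch added here: the stock `BitNetTransformer.forward` takes integer token ids and embeds them internally, but this repo builds its own input embeddings, so the patch feeds them straight to the decoder stack and, assuming `to_logits` is a Sequential of a final norm followed by a token projection, `to_logits[0]` keeps only the norm, leaving the logit projection to the repo's own classifier. A minimal sketch of the patched class in isolation (constructor kwargs come from the later hunk; sizes are hypothetical):

```python
import torch
from bitnet import BitNetTransformer  # https://github.com/kyegomez/BitNet

def NoEmbedding_BitNetTransformer_Forward(self, x):
    x = self.transformer(x)      # x: (batch, time, dim) float embeddings, not token ids
    return self.to_logits[0](x)  # final norm only; no projection to the vocabulary

BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward

model = BitNetTransformer(num_tokens=1024, dim=256, depth=4, heads=4, ff_mult=4)
emb = torch.randn(1, 50, 256)    # embeddings built by the caller
out = model(emb)                 # expected shape: (1, 50, 256), ready for an external classifier
```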
@@ -325,7 +338,7 @@ class Base(nn.Module):
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-        elif self.arch_type == "mistral":
+        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
                     vocab_size=n_resp_tokens,
@@ -425,6 +438,16 @@ class Base(nn.Module):
             ))
 
             self.model = RetNetDecoder(RetNetConfig(**kwargs))
+        elif self.arch_type == "bitnet":
+            self.model = BitNetTransformer(
+                num_tokens=n_resp_tokens,
+                dim=d_model,
+                depth=n_layers,
+                heads=n_heads,
+                ff_mult=4,
+            )
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -486,7 +509,7 @@ class Base(nn.Module):
             # grab last token(s)
             x = x[:, -1, :].unsqueeze(1)
         # HF transformer derived model
-        elif self.arch_type == "llama" or self.arch_type == "mistral":
+        elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
             kwargs = dict(
                 #attention_mask=m,
                 inputs_embeds=x,
@@ -521,6 +544,8 @@ class Base(nn.Module):
             x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
             if _ is not None and "l_aux" in _ and self.n_experts > 1:
                 aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
+        elif self.arch_type == "bitnet":
+            x = self.model(x)
         # output projection layer with masking
         x = self.classifier(x) * m
@@ -7,11 +7,19 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 
+# https://github.com/kyegomez/BitNet
+if cfg.bitsandbytes.bitnet:
+    from bitnet import BitLinear
+
 if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb
 
     if cfg.bitsandbytes.linear:
-        Linear = bnb.nn.Linear8bitLt
+        if cfg.bitsandbytes.bitnet:
+            Linear = BitLinear
+        else:
+            Linear = bnb.nn.Linear8bitLt
 
     if cfg.bitsandbytes.embedding:
         Embedding = bnb.nn.modules.Embedding
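The hunk above keeps the `Linear`/`Embedding` aliases in this helper module as the single place where quantization and BitNet substitution are decided, with `bitsandbytes.bitnet` taking precedence over the 8-bit kernel when both are enabled. A condensed, self-contained sketch of that selection logic (the `cfg` stub stands in for the repo's config object):

```python
import torch
from types import SimpleNamespace

# Stand-in for the repo's `cfg.bitsandbytes` block (see the config hunk above).
cfg = SimpleNamespace(bitsandbytes=SimpleNamespace(enabled=True, linear=True, embedding=True, bitnet=True))

Embedding = torch.nn.Embedding
Linear = torch.nn.Linear

if cfg.bitsandbytes.bitnet:
    from bitnet import BitLinear          # https://github.com/kyegomez/BitNet

if cfg.bitsandbytes.enabled:
    import bitsandbytes as bnb

    if cfg.bitsandbytes.linear:
        # BitLinear wins over the 8-bit kernel when both flags are set.
        Linear = BitLinear if cfg.bitsandbytes.bitnet else bnb.nn.Linear8bitLt

    if cfg.bitsandbytes.embedding:
        Embedding = bnb.nn.modules.Embedding

# Downstream modules construct layers through these aliases (e.g. `Linear(d_model, n_tokens)`),
# so flipping the config flags swaps the implementation everywhere at once.
```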