Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares)
parent 3da1518ace
commit 35d78a2bb0

README.md: 28 lines changed
@@ -6,9 +6,9 @@
 An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.

-[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
+[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)

-> **Note** This README is still quite a disorganized mess.
+> **Note** Development on this is very sporadic. Gomen.

 ## Requirements

@@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
   - Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
   - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
-    + additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
+    + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.

 ## Install

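Picking up on the `PHONEMIZER_ESPEAK_LIBRARY` note above: a minimal sketch of what setting it can look like on Windows, assuming the default install location used by the official espeak-ng installer (adjust the path to your system):

```python
# Illustrative only: point phonemizer at libespeak-ng.dll before any phonemization.
import os
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"

from phonemizer import phonemize
print(phonemize("hello world", language="en-us", backend="espeak"))
```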
@@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 ## Try Me

-### Online
-
-A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
-
-### Local
-
 To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`

 Each model file has a barebones trainer and inference routine.

@@ -52,7 +46,7 @@ Training is very dependent on:
 * the quality of your dataset.
 * how much data you have.
 * the bandwidth you quantized your audio to.
-* the underlying model architecture used
+* the underlying model architecture used.

 ### Pre-Processed Dataset

@@ -105,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train
 #### Training on Low-VRAM Cards

-During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16.
+During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation).

 VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.

 Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).

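To make the BitsAndBytes remark above concrete: the AdamW saving it refers to comes from keeping optimizer state in 8-bit. A minimal, hedged sketch of swapping it in (not the repo's exact wiring, which goes through its config flags):

```python
import torch

# Fall back to vanilla AdamW when bitsandbytes is unavailable.
try:
    import bitsandbytes as bnb
    AdamW = bnb.optim.AdamW8bit   # optimizer state kept in 8-bit, cutting VRAM use
except ImportError:
    AdamW = torch.optim.AdamW

model = torch.nn.Linear(1024, 1024)           # stand-in for the actual model
optimizer = AdamW(model.parameters(), lr=1.0e-4)
```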
+#### Backend Architectures
+
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
+
+* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
+* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
+  - Its implementation for MoE can also be utilized.
+* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
+* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
+* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
+  - Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
+
 ## Export

 To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.

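Note the distinction, judging from the accompanying code changes: `arch_type: bitnet` swaps the entire backbone for kyegomez/BitNet's `BitNetTransformer`, while the separate `bitsandbytes.bitnet` flag only swaps the repo's own `Linear` layers for `BitLinear`.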
@@ -532,6 +532,8 @@ class BitsAndBytes:
     linear: bool = True
     embedding: bool = True
+
+    bitnet: bool = False

 @dataclass()
 class Config(_Config):
     device: str = "cuda"

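For reference, the new `bitnet` flag defaults to `False`, so BitNet's `BitLinear` stays opt-in; this appears to be the field that the README's `bitsandbytes.bitnet=True` setting toggles.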
@@ -362,8 +362,8 @@ def example_usage():

     model = AR_NAR(**kwargs).to(device)
     steps = 500
-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+    #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)

     torch.save( {

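(Side note: Prodigy estimates its own effective step size, which is why it was run with `lr=1.0`; the plain AdamW path uses a conventional `1.0e-4`.)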
@@ -40,6 +40,19 @@ except Exception as e:
     print("Error importing `mistral` arch:", e)
     pass

+try:
+    from bitnet import BitNetTransformer
+
+    def NoEmbedding_BitNetTransformer_Forward(self, x):
+        x = self.transformer(x)
+        return self.to_logits[0](x)
+
+    BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
+
+except Exception as e:
+    print("Error importing `bitnet` arch:", e)
+    pass
+
 try:
     from transformers import MixtralModel, MixtralConfig
     from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock

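As a hedged illustration of what the patched `forward` above changes (sizes are made up, and it assumes `to_logits[0]` is the upstream model's final norm rather than its vocab projection): VALL-E's `Base` model computes its own summed embeddings, so the stock token-id path is bypassed and the backbone returns hidden features, leaving the projection to logits to the repo's own classifier.

```python
import torch
from bitnet import BitNetTransformer  # https://github.com/kyegomez/BitNet

# Constructor mirrors the call in the `Base` constructor hunk below; sizes are illustrative.
model = BitNetTransformer(num_tokens=1024, dim=512, depth=12, heads=8, ff_mult=4)

def forward_from_embeddings(self, x):
    # x: [batch, seq_len, dim] float embeddings, not token ids
    x = self.transformer(x)       # run the BitNet transformer stack directly
    return self.to_logits[0](x)   # first stage of to_logits only; its vocab projection is skipped

BitNetTransformer.forward = forward_from_embeddings
features = model(torch.randn(1, 32, 512))  # -> [1, 32, 512] features for the external classifier
```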
@@ -325,7 +338,7 @@ class Base(nn.Module):
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-        elif self.arch_type == "mistral":
+        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
                     vocab_size=n_resp_tokens,

@@ -425,6 +438,16 @@ class Base(nn.Module):
             ))

             self.model = RetNetDecoder(RetNetConfig(**kwargs))
+        elif self.arch_type == "bitnet":
+            self.model = BitNetTransformer(
+                num_tokens=n_resp_tokens,
+                dim=d_model,
+                depth=n_layers,
+                heads=n_heads,
+                ff_mult=4,
+            )
+        else:
+            raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

         self.classifier = nn.Linear(d_model, n_resp_tokens)

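Worth noting about this branch: `num_tokens` is still passed to `BitNetTransformer` even though its internal embedding path goes unused because of the patched forward shown earlier, and the new trailing `else` turns an unrecognized `arch_type` into an immediate `RuntimeError` rather than leaving `self.model` unset until some later attribute error.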
@@ -486,7 +509,7 @@ class Base(nn.Module):
             # grab last token(s)
             x = x[:, -1, :].unsqueeze(1)
         # HF transformer derived model
-        elif self.arch_type == "llama" or self.arch_type == "mistral":
+        elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
             kwargs = dict(
                 #attention_mask=m,
                 inputs_embeds=x,

@@ -521,6 +544,8 @@ class Base(nn.Module):
             x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
             if _ is not None and "l_aux" in _ and self.n_experts > 1:
                 aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
+        elif self.arch_type == "bitnet":
+            x = self.model(x)
         # output projection layer with masking
         x = self.classifier(x) * m

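For the `bitnet` path the backbone therefore only returns hidden states; the shared `self.classifier` right below projects them to logits and applies the mask, exactly as it does for the other backends.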
@@ -7,10 +7,18 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

+# https://github.com/kyegomez/BitNet
+if cfg.bitsandbytes.bitnet:
+    from bitnet import BitLinear
+
 if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb

     if cfg.bitsandbytes.linear:
-        Linear = bnb.nn.Linear8bitLt
+        if cfg.bitsandbytes.bitnet:
+            Linear = BitLinear
+        else:
+            Linear = bnb.nn.Linear8bitLt

     if cfg.bitsandbytes.embedding:

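To make the intent of these aliases concrete, here is a minimal standalone sketch of the same pattern (plain booleans stand in for the `cfg.bitsandbytes.*` flags, and the layer size is arbitrary): model code builds layers through the module-level `Linear` name, so one config switch swaps the implementation everywhere without touching the call sites.

```python
import torch

# Stand-ins for cfg.bitsandbytes.* purely for illustration.
BNB_ENABLED = True
BNB_LINEAR = True
BITNET = True

Linear = torch.nn.Linear

if BITNET:
    from bitnet import BitLinear       # ternary "1-bit"-style linear from kyegomez/BitNet

if BNB_ENABLED:
    import bitsandbytes as bnb
    if BNB_LINEAR:
        Linear = BitLinear if BITNET else bnb.nn.Linear8bitLt

# Call sites stay identical regardless of which implementation is active.
proj = Linear(512, 512)
```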