From 35d78a2bb0e05c933a12d82f079dbf21cb5fa23e Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 29 Feb 2024 20:29:17 -0600 Subject: [PATCH] Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares) --- README.md | 28 +++++++++++++++++----------- vall_e/config.py | 2 ++ vall_e/models/ar_nar.py | 4 ++-- vall_e/models/base.py | 29 +++++++++++++++++++++++++++-- vall_e/utils/wrapper.py | 10 +++++++++- 5 files changed, 57 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index bd2d96d..75106d1 100755 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder. -[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e) +[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) -> **Note** This README is still quite a disorganized mess. +> **Note** Development on this is very sporadic. Gomen. ## Requirements @@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed. - Linux users can consult their package managers on installing `espeak`/`espeak-ng`. - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets). - + additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`. + + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`. ## Install @@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`. ## Try Me -### Online - -A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e). - -### Local - To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"` Each model file has a barebones trainer and inference routine. @@ -52,7 +46,7 @@ Training is very dependent on: * the quality of your dataset. * how much data you have. * the bandwidth you quantized your audio to. -* the underlying model architecture used +* the underlying model architecture used. ### Pre-Processed Dataset @@ -105,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train #### Training on Low-VRAM Cards -During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16. +During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation). VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful. Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU). +#### Backend Architectures + +As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported: + +* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards. +* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead. + - Its implementation for MoE can also be utilized. +* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements. +* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation. +* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer. + - Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation. + ## Export To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`. diff --git a/vall_e/config.py b/vall_e/config.py index 3aa83c3..fca6736 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -531,6 +531,8 @@ class BitsAndBytes: linear: bool = True embedding: bool = True + + bitnet: bool = False @dataclass() class Config(_Config): diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 7b7ee9a..3b5d875 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -362,8 +362,8 @@ def example_usage(): model = AR_NAR(**kwargs).to(device) steps = 500 - optimizer = ml.Prodigy(model.parameters(), lr=1.0) - #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) + #optimizer = ml.Prodigy(model.parameters(), lr=1.0) + optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) engine = Engine(model=model, optimizer=optimizer) torch.save( { diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 220751c..26eb385 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -40,6 +40,19 @@ except Exception as e: print("Error importing `mistral` arch:", e) pass +try: + from bitnet import BitNetTransformer + + def NoEmbedding_BitNetTransformer_Forward(self, x): + x = self.transformer(x) + return self.to_logits[0](x) + + BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward + +except Exception as e: + print("Error importing `bitnet` arch:", e) + pass + try: from transformers import MixtralModel, MixtralConfig from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock @@ -325,7 +338,7 @@ class Base(nn.Module): norm_type=self.norm_type, n_levels=self.n_resp_levels, ) for _ in range(n_layers) ]) - elif self.arch_type == "mistral": + elif self.arch_type == "mistral" or self.arch_type == "mixtral": if n_experts <= 1: self.model = MistralModel(MistralConfig( vocab_size=n_resp_tokens, @@ -425,6 +438,16 @@ class Base(nn.Module): )) self.model = RetNetDecoder(RetNetConfig(**kwargs)) + elif self.arch_type == "bitnet": + self.model = BitNetTransformer( + num_tokens=n_resp_tokens, + dim=d_model, + depth=n_layers, + heads=n_heads, + ff_mult=4, + ) + else: + raise RuntimeError(f'Unknown arch specified: {self.arch_type}') self.classifier = nn.Linear(d_model, n_resp_tokens) @@ -486,7 +509,7 @@ class Base(nn.Module): # grab last token(s) x = x[:, -1, :].unsqueeze(1) # HF transformer derived model - elif self.arch_type == "llama" or self.arch_type == "mistral": + elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral": kwargs = dict( #attention_mask=m, inputs_embeds=x, @@ -521,6 +544,8 @@ class Base(nn.Module): x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True) if _ is not None and "l_aux" in _ and self.n_experts > 1: aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001 + elif self.arch_type == "bitnet": + x = self.model(x) # output projection layer with masking x = self.classifier(x) * m diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index 62ac50e..b00399d 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -7,11 +7,19 @@ from ..config import cfg Embedding = torch.nn.Embedding Linear = torch.nn.Linear +# https://github.com/kyegomez/BitNet +if cfg.bitsandbytes.bitnet: + from bitnet import BitLinear + if cfg.bitsandbytes.enabled: import bitsandbytes as bnb if cfg.bitsandbytes.linear: - Linear = bnb.nn.Linear8bitLt + + if cfg.bitsandbytes.bitnet: + Linear = BitLinear + else: + Linear = bnb.nn.Linear8bitLt if cfg.bitsandbytes.embedding: Embedding = bnb.nn.modules.Embedding