Compare commits


13 Commits

Author SHA1 Message Date
mrq 7075c2a5f0 added an option to allow injecting embeddings from another model, because it dawned upon me how valuable embeddings from a good model can be for subsequent trainings (defined under cfg.models._embeddings as a relative path to the yaml) 2024-04-04 19:11:49 +07:00
mrq 91062361af tweaks 2024-03-01 20:38:06 +07:00
mrq f3c59c3e7e cleaner replacement code (because I realized BitNet had an implementation for it too), added calculating gradient norm and performing gradient clipping in local trainer (non-deepspeed) 2024-03-01 20:18:43 +07:00
mrq 47435207f7 Added cfg.bitsandbytes.replace as a less intrusive alternative to cfg.bitsandbytes.inject to replace all Linear modules in a model 2024-03-01 19:20:10 +07:00
mrq 0427d8d076 logger broke for some reason, added flag to just tqdm.write instead, make cfg.bitsandbytes.bitnet==True yamls denoted since I'm sure they're not interoperable 2024-03-01 10:32:35 +07:00
mrq 35d78a2bb0 Yet Another Underlying Transformer Implementation (BitNet, will give it a few days to see how it fares) 2024-02-29 20:29:17 +07:00
mrq 3da1518ace added Mistral (non-Mixtral) backend, useless optimization when not training, proper adjustment of the LR for Prodigyopt through d_coeff (maybe), recurrent sampling for LLaMA/Mistral/Mixtral backends (again, doesn't actually work) 2024-01-31 21:48:36 +07:00
mrq cce929e136 nasty hotfix for transformer's Mixtral throwing an error when batch sizes > 1 2024-01-26 19:41:12 +07:00
mrq e799665759 experimental weighting of prom/resp embeds 2024-01-25 12:18:48 +07:00
mrq c690aa509d fixes and compat (MoE-fying an existing model and retraining from there just ruins it after a second of audio...) 2023-12-25 21:20:32 +07:00
mrq e513d2ef19 experts weren't forwarded into constructer (wasted a few days of training garbage) 2023-12-23 16:08:17 +07:00
mrq 0db3203b21 added LLaMA/Mixtral (if experts>1) model arches, utilize XMoE's loss as well, set MoE frequency to 1 to make every layer MoE'd for RetNet, etc. (going to do tests without burning out again to see how things go) 2023-12-22 19:27:36 +07:00
mrq 9c198eb75a added torchscale XMOE integration (because Mixtral 8x7B seems very promising and I want to see if it works) 2023-12-20 18:45:58 +07:00
18 changed files with 546 additions and 108 deletions

@ -6,9 +6,9 @@
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
> **Note** This README is still quite a disorganized mess.
> **Note** Development on this is very sporadic. Gomen.
## Requirements
@ -20,7 +20,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
- Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
- additionally, you may be require dto set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
+ additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
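One way to do this is to set the variable before `phonemizer` is imported; a minimal sketch (the DLL path below is an assumption for a default espeak-ng install, adjust to your system):

```python
# Minimal sketch: point phonemizer at the espeak-ng DLL before it is imported.
# The path is an assumption for a default espeak-ng install; adjust as needed.
import os
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
```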
## Install
@ -30,12 +30,6 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
## Try Me
### Online
A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
### Local
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
Each model file has a barebones trainer and inference routine.
@ -52,6 +46,7 @@ Training is very dependent on:
* the quality of your dataset.
* how much data you have.
* the bandwidth you quantized your audio to.
* the underlying model architecture used.
### Pre-Processed Dataset
@ -104,12 +99,24 @@ Keep in mind that creature comforts like distributed training or `float16` train
#### Training on Low-VRAM Cards
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16.
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation).
VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized.
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
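Which backend is built is controlled by two fields on the `Model` config shown in the `vall_e/config.py` hunk further down: `arch_type` picks the implementation, and `experts > 1` switches the LLaMA/RetNet paths to their MoE variants (Mixtral and XMoE respectively). A minimal sketch, with illustrative values only:

```python
# Illustrative sketch of the two Model fields (from the config.py hunk below) that select the backend.
from dataclasses import dataclass

@dataclass
class Model:
    arch_type: str = "retnet"  # "transformer", "llama", "mistral", "mixtral", "retnet", or "bitnet"
    experts: int = 1           # > 1 builds the MoE variant (Mixtral for llama/mistral, XMoE for retnet)

ar = Model(arch_type="llama", experts=8)  # an 8-expert MoE'd LLaMA-style model
```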
## Export
To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.

@ -48,6 +48,7 @@ setup(
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"transformers>4.37.0",
"pandas>=1.5.0",
"torch>=1.13.0",

@ -169,6 +169,7 @@ class Model:
prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs: int = 1 # defined languages
experts: int = 1
arch_type: str = "retnet" # or "transformer"
training: bool = True # unneeded now
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
@ -183,12 +184,19 @@ class Model:
name.append(self.size)
if self.arch_type != "transformer":
name.append(self.arch_type.replace("/", "-"))
if self.experts > 1:
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
else:
name.append(self.arch_type.replace("/", "-"))
if cfg.bitsandbytes.bitnet:
name.append("bitnet")
if self.interleave:
name.append("interleaved")
else:
name.append(f'{cfg.models.prom_levels}')
name.append(f'{cfg.models.prom_levels}')
return "-".join(name)
@ -245,10 +253,11 @@ class Model:
class Models:
_max_levels: int = 0
_prom_levels: int = 1
_embeddings: str | None = None
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
])
def get(self, name=None):
@ -295,7 +304,7 @@ class Models:
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int = 100
gradient_clipping: int | float = 100
optimizer: str = "Adamw"
torch_optimizer: bool = False
@ -483,6 +492,7 @@ class Trainer:
amp: bool = False
load_webui: bool = False
no_logger: bool = False
backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@ -523,9 +533,12 @@ class Inference:
class BitsAndBytes:
enabled: bool = False
injects: bool = False
replace: bool = False
linear: bool = True
embedding: bool = True
bitnet: bool = False
@dataclass()
class Config(_Config):

@ -224,7 +224,7 @@ class Dataset(_Dataset):
self.spkrs_by_spkr_group[spkr_group].append( spkr )
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
if self.sampler_type == "path":
@ -351,7 +351,7 @@ class Dataset(_Dataset):
# shuffle it up a bit
prom_length = 0
if cfg.experimental:
trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
trim_length = random.randint(75 * 3, 75 * 6) # [3 seconds, 6 seconds]
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
else:
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

@ -13,6 +13,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
from ..models import get_models
from ..utils import wrapper as ml
import torch
import re
deepspeed_available = False
try:
@ -24,8 +25,8 @@ except Exception as e:
from functools import cache
@cache
def load_engines():
models = get_models(cfg.models.get())
def load_engines(training=True):
models = get_models(cfg.models.get(), training=training)
engines = dict()
for name, model in models.items():
@ -43,6 +44,9 @@ def load_engines():
if inferencing:
model._cfg.training = False
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
model.model = ml.replace_linear( model.model )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
optimizer_class = None
params = {
@ -58,6 +62,9 @@ def load_engines():
optimizer = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
optimizer_class = ml.Prodigy
params['d_coef'] = params['lr']
params['lr'] = 1.0
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
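The Prodigy handling above mirrors how prodigyopt expects to be driven: the optimizer's own `lr` stays at 1.0 and the configured learning rate is repurposed as the `d_coef` multiplier. A standalone sketch of that idea (toy model; not the repo's exact construction path):

```python
# Standalone sketch of the Prodigy setup above: lr stays 1.0, the configured LR becomes d_coef.
import torch
from prodigyopt import Prodigy

model = torch.nn.Linear(8, 8)                      # toy model
params = {"params": model.parameters(), "lr": 1.0e-4}
params["d_coef"] = params["lr"]                    # repurpose the configured LR as d_coef
params["lr"] = 1.0                                 # Prodigy's own lr should remain at 1.0
optimizer = Prodigy(**params)
```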
@ -90,8 +97,26 @@ def load_engines():
if "module" in state:
state = state["module"]
# maintain compat if I change variable names
insert = {}
erase = []
for k in state.keys():
key = re.sub(r'^retnet\.', "model.", k)
if k != key:
insert[key] = state[k]
erase.append(k)
for k in insert.keys():
state[k] = insert[k]
for k in erase:
del state[k]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# deepspeed inferencing
if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = _Engine
@ -117,6 +142,33 @@ def load_engines():
for name, engine in engines.items():
engine.freeze(freeze_all=False)
# copy embeddings if requested
if cfg.models._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings
if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
if "module" in embeddings:
embeddings = embeddings["module"]
frozen_params = set()
for k in list(embeddings.keys()):
if re.findall(r'_emb\.', k):
frozen_params.add(k)
else:
del embeddings[k]
engine.module.load_state_dict(embeddings, strict=False)
# there's definitely a much better way but I can't be assed at the moment
for name, param in engine.module.named_parameters():
if name not in frozen_params:
continue
param.requires_grad_(False)
engine._frozen_params.add(param)
#do_gc()
return engines

@ -45,7 +45,7 @@ from .base import TrainFeeder
_logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
@ -102,6 +102,10 @@ class Engine():
@property
def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps
@property
def gradient_clipping(self):
return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@ -186,24 +190,36 @@ class Engine():
self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
self.global_steps += 1
self.optimizer.step()
self.optimizer.zero_grad()
self._get_grad_norm()
def _get_grad_norm(self):
t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coeff' in param_group:
lrs.append(param_group['d_coeff'])
elif 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coeff' in param_group:
param_group['d_coeff'] = lr
elif 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
return 0.0
return self._global_grad_norm
def traverse(self, *args, **kwargs):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):

@ -99,7 +99,7 @@ class Engine(DeepSpeedEngine):
try:
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
else:
self.optimizer.set_lr(lr)
except Exception as e:

@ -73,7 +73,7 @@ class TTS():
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get())
models = get_models(cfg.models.get(), training=False)
for name, model in models.items():
if name.startswith("ar"):
@ -101,7 +101,7 @@ class TTS():
self.loading = False
def load_models( self ):
engines = load_engines()
engines = load_engines(training=False)
for name, engine in engines.items():
if name.startswith("ar"):
self.ar = engine.module
@ -175,40 +175,47 @@ class TTS():
mirostat_eta=0.1,
out_path=None
):
if out_path is None:
out_path = f"./data/{cfg.start_time}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length )
phns = self.encode_text( text, language=language )
lang = self.encode_lang( language )
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
lines = text.split("\n")
wavs = []
sr = None
for line in lines:
if out_path is None:
out_path = f"./data/{cfg.start_time}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length )
phns = self.encode_text( line, language=language )
lang = self.encode_lang( language )
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
wavs.append(wav)
return (wav, sr)
return (torch.concat(wavs, dim=-1), sr)
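The change above makes `inference()` synthesize each newline-separated line independently and concatenate the results; a toy sketch of just that control flow (independent of the model code, `synthesize_line` is a stand-in):

```python
# Toy sketch of the new control flow above: per-line synthesis, concatenated at the end.
import torch

def synthesize_line(line: str) -> torch.Tensor:
    # stand-in for the AR + NAR passes and EnCodec decode; returns a (channels, samples) waveform
    return torch.zeros(1, 24_000)

def inference(text: str) -> torch.Tensor:
    wavs = [synthesize_line(line) for line in text.split("\n")]
    return torch.concat(wavs, dim=-1)

print(inference("first line\nsecond line").shape)  # torch.Size([1, 48000])
```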

@ -2,7 +2,7 @@ from .ar import AR
from .nar import NAR
from .ar_nar import AR_NAR
def get_model(cfg):
def get_model(cfg, training=True):
if cfg.name == "ar":
Model = AR
elif cfg.name == "nar":
@ -18,7 +18,9 @@ def get_model(cfg):
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
n_experts=cfg.experts,
training=training,
config = cfg,
)
model._cfg = cfg
@ -27,5 +29,5 @@ def get_model(cfg):
return model
def get_models(models):
return { model.full_name: get_model(model) for model in models }
def get_models(models, training=True):
return { model.full_name: get_model(model, training=training) for model in models }

@ -94,7 +94,7 @@ class AR_NAR(Base):
lang_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_levels: int = 7,
max_levels: int = 0,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
@ -119,10 +119,26 @@ class AR_NAR(Base):
# is training
if n_levels == self.n_resp_levels:
# might be better to have this decided on the dataloader level
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if cfg.experimental and False:
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
else:
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
"""
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else:
quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
"""
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
@ -150,7 +166,7 @@ class AR_NAR(Base):
)
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels
max_levels = self.n_resp_levels - 1
prev_list = resps_list
@ -212,13 +228,22 @@ class AR_NAR(Base):
else:
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
if recurrent_state is not None:
logits, recurrent_state = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
else:
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
r = super().sample(
logits=logits,
@ -284,6 +309,7 @@ def example_usage():
from ..engines import Engine
from tqdm import tqdm
from ..utils import wrapper as ml
import re
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@ -311,12 +337,23 @@ def example_usage():
proms_list = proms_list[:1]
resps_list = resps_list[:1]
# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_tokens': 1024,
'd_model': 1024, # 1536
'n_heads': 16, # 24
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1,
}
"""
kwargs = {
'n_tokens': 1024,
'd_model': 256,
'n_heads': 4,
'n_layers': 12,
'n_experts': 8,
}
"""
"""
try:
@ -326,11 +363,38 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 250
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
# copy embeddings if requested
if cfg.models._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings
if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
if "module" in embeddings:
embeddings = embeddings["module"]
frozen_params = set()
for k in list(embeddings.keys()):
if re.findall(r'_emb\.', k):
frozen_params.add(k)
else:
del embeddings[k]
engine.module.load_state_dict(embeddings, strict=False)
for name, param in engine.module.named_parameters():
if name not in frozen_params:
continue
param.requires_grad_(False)
engine._frozen_params.add(param)
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
model.model = ml.replace_linear( model.model )
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
@ -359,9 +423,14 @@ def example_usage():
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
sample("init", 5)
train()
sample("final")

@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re
from typing import Literal, overload
from functools import partial
@ -14,10 +15,112 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
try:
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
except Exception as e:
print("Error importing `transformer` arch:", e)
pass
try:
from .retnet import RetNetDecoder, RetNetConfig
except Exception as e:
print("Error importing `retnet` arch:", e)
pass
try:
from transformers import LlamaModel, LlamaConfig
except Exception as e:
print("Error importing `llama` arch:", e)
pass
try:
from transformers import MistralModel, MistralConfig
except Exception as e:
print("Error importing `mistral` arch:", e)
pass
try:
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
class BitNetTransformer(nn.Module):
def __init__(
self,
dim: int,
depth: int,
num_tokens: int,
heads=8,
ff_mult=4,
):
super().__init__()
self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
self.norm = BitNetRMSNorm(dim)
def forward(self, x):
x = self.transformer(x)
return self.norm( x )
"""
from bitnet import BitNetTransformer
def NoEmbedding_BitNetTransformer_Forward(self, x):
x = self.transformer(x)
return self.to_logits[0](x)
BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
"""
except Exception as e:
print("Error importing `bitnet` arch:", e)
pass
try:
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
# This is required because batch sizes > 1 throws errors
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.shape[0] == 0:
continue
top_x_list = top_x.tolist()
idx_list = idx.tolist()
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
except Exception as e:
print("Error importing `mixtral` arch:", e)
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -101,9 +204,10 @@ class MultiEmbedding(nn.Module):
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(self, l_tokens, token_dim):
def __init__(self, l_tokens, token_dim, levels=None):
super().__init__()
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
@ -111,13 +215,13 @@ class AudioEmbedding(nn.Module):
for i, xi in enumerate(x_list):
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
res_list.append(x)
return res_list
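The `weight` parameter added to `AudioEmbedding` above is a per-RVQ-level learnable scalar applied to each level's embedding before the levels are summed (the "experimental weighting of prom/resp embeds" commit). A standalone sketch of just that idea, not the repo's exact class (which also handles the AR/NAR level offsets):

```python
# Standalone sketch of per-level weighting: each RVQ level's embedding is scaled by its
# own learnable scalar before summing.
import torch
from torch import nn

class WeightedSumEmbedding(nn.Module):
    def __init__(self, n_tokens: int, token_dim: int, levels: int):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(levels)])
        self.weight = nn.ParameterList([nn.Parameter(torch.tensor([1.0])) for _ in range(levels)])

    def forward(self, x):  # x: (seq_len, levels) of token ids
        return sum(self.embeddings[k](x[:, k]) * self.weight[k] for k in range(x.shape[-1]))

x = torch.randint(0, 1024, (50, 8))
emb = WeightedSumEmbedding(1024, 512, 8)
print(emb(x).shape)  # torch.Size([50, 512])
```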
@ -204,9 +308,13 @@ class Base(nn.Module):
n_layers: int = 12,
p_dropout: float = 0.1,
n_experts: int=1,
training = True,
config = None,
):
super().__init__()
self.training = training
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
@ -214,6 +322,7 @@ class Base(nn.Module):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.n_experts = n_experts
# +1 to include the stop token
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
@ -230,9 +339,9 @@ class Base(nn.Module):
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
# [1024] * 8
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
if self.version >= 3:
@ -246,20 +355,88 @@ class Base(nn.Module):
self.blocks = nn.ModuleList([TransformerBlock(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout,
p_dropout=p_dropout if training else 0.0,
causal=self.causal,
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "mistral" or self.arch_type == "mixtral":
if n_experts <= 1:
self.model = MistralModel(MistralConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
))
elif self.arch_type == "llama":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
))
elif self.arch_type == "retnet":
self.retnet = RetNetDecoder(RetNetConfig(
vocab_size=n_tokens,
kwargs = dict(
vocab_size=n_resp_tokens,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.activation_checkpointing,
activation_fn="gelu",
use_layernorm=True, # self.version < 3,
@ -272,7 +449,27 @@ class Base(nn.Module):
decoder_normalize_before=True,
rotary_embedding_base=self.rotary_embedding_base, # 10000
))
)
if n_experts > 1:
kwargs.update(dict(
use_xmoe=True,
moe_freq=1,
moe_expert_count=n_experts,
moe_gating_use_fp32=False,
))
self.model = RetNetDecoder(RetNetConfig(**kwargs))
elif self.arch_type == "bitnet":
self.model = BitNetTransformer(
num_tokens=n_resp_tokens,
dim=d_model,
depth=n_layers,
heads=n_heads,
ff_mult=4,
)
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
self.classifier = nn.Linear(d_model, n_resp_tokens)
@ -302,7 +499,7 @@ class Base(nn.Module):
lang_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | None = None,
state: dict | list | None = None,
):
batch_size = len(text_list)
@ -318,6 +515,7 @@ class Base(nn.Module):
)
x, m = list_to_tensor(x_list)
aux_loss = None
device = x.device
@ -328,12 +526,33 @@ class Base(nn.Module):
# run the initial prompt to fill the KV cache
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
# HF transformer derived model
elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
kwargs = dict(
#attention_mask=m,
inputs_embeds=x,
past_key_values=state,
use_cache=state is not None,
# return_dict=True,
)
if self.n_experts > 1 and targ_list is not None:
kwargs["output_router_logits"] = True
if self.arch_type == "transformer":
t = self.model(**kwargs)
x = t[0]
if state is not None:
state = t[1]
if self.n_experts > 1 and targ_list is not None:
router_logits = t[-1]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer":
# ensures we specify a quant_level for the transformer implementation's AdaLN
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
@ -344,8 +563,11 @@ class Base(nn.Module):
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
if _ is not None and "l_aux" in _ and self.n_experts > 1:
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
elif self.arch_type == "bitnet":
x = self.model(x)
# output projection layer with masking
x = self.classifier(x) * m
@ -354,7 +576,6 @@ class Base(nn.Module):
# compute loss if the target is given
if targ_list is not None:
target_list = self._samplewise_merge_tensors(
text_list,
lang_list,
@ -380,10 +601,13 @@ class Base(nn.Module):
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
precision = self.precision_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
if aux_loss is not None:
self.loss["nll"] += aux_loss
return logits
return (logits, state) if state is not None else logits
def sample(
self,
@ -432,7 +656,7 @@ class Base(nn.Module):
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
# epsilon float comparison because I don't trust Python
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:

@ -80,8 +80,8 @@ class Attention(nn.Module):
self.n_heads = n_heads
self.scale = dim_head**-0.5
self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
self.to_out = nn.Linear(d_model, d_model)
self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
self.to_out = ml.Linear(d_model, d_model)
def forward(self, x, m):
"""
@ -169,10 +169,10 @@ class Block(nn.Sequential):
n_ff = d_model * 4 # 1024 * 4 = 4096 feed-forwards
self.ffn = PrenormResidual(
nn.Sequential(
nn.Linear(d_model, n_ff),
ml.Linear(d_model, n_ff),
nn.GELU(),
nn.Dropout(p_dropout),
nn.Linear(n_ff, d_model),
ml.Linear(n_ff, d_model),
),
d_model=d_model,
p_dropout=p_dropout,

@ -146,7 +146,10 @@ def run_eval(engines, eval_name, dl):
}
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
if cfg.trainer.no_logger:
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
else:
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def train():

@ -16,6 +16,7 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn, *args, **kwargs ):
#print("Initializing distributed...")
fn(*args, **kwargs)
_distributed_initialized = True

@ -30,7 +30,7 @@ from .distributed import (
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
from .utils import to_device, do_gc
from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml
from ..data import get_phone_symmap # should decouple from this trainer script
@ -174,7 +174,11 @@ def train(
elapsed_time = stats.get("elapsed_time", 0)
metrics = json.dumps(stats)
_logger.info(f"Training Metrics: {metrics}.")
if cfg.trainer.no_logger:
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
else:
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input()

@ -19,6 +19,13 @@ from typing import Callable, TypeVar, overload
T = TypeVar("T")
def truncate_json( str ):
def fun( match ):
return "{:.4f}".format(float(match.group()))
return re.sub(r"\d+\.\d{8,}", fun, str)
def do_gc():
gc.collect()
torch.cuda.empty_cache()
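A quick usage sketch of the `truncate_json` helper added above (it rounds any float with eight or more decimal places inside a JSON string down to four):

```python
# Usage sketch for truncate_json above: long floats in the metrics JSON get rounded to 4 places.
import json, re

def truncate_json(s: str) -> str:
    return re.sub(r"\d+\.\d{8,}", lambda m: "{:.4f}".format(float(m.group())), s)

metrics = json.dumps({"loss.nll": 3.141592653589793, "lr": 0.0001})
print(truncate_json(metrics))  # {"loss.nll": 3.1416, "lr": 0.0001}
```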

@ -7,11 +7,19 @@ from ..config import cfg
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
# https://github.com/kyegomez/BitNet
if cfg.bitsandbytes.bitnet:
from bitnet import BitLinear
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
if cfg.bitsandbytes.linear:
Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes.bitnet:
Linear = BitLinear
else:
Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes.embedding:
Embedding = bnb.nn.modules.Embedding
@ -75,6 +83,28 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
def replace_linear( model ):
device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
for *parent, k in linears:
name = '.'.join(parent)
# copy parameters
m = getattr( model.get_submodule(name), k )
in_features = m.in_features
out_features = m.out_features
bias = m.bias is not None
# overwrite
setattr(
model.get_submodule(name), k,
Linear( in_features=in_features, out_features=out_features, bias=bias )
)
return model.to(device) # because our now Linear is created on the CPU......
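A usage sketch of `replace_linear` above, mirroring the call sites added in `load_engines` and the `example_usage` hunks (assumes a built model and the `bitsandbytes.enabled`/`replace` flags set in the YAML):

```python
# Usage sketch mirroring the call sites in this compare (load_engines / example_usage).
from vall_e.config import cfg
from vall_e.utils import wrapper as ml

# ... after the model has been constructed ...
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
    # swap every torch.nn.Linear in the underlying transformer for the configured
    # Linear class (bnb's Linear8bitLt, or BitNet's BitLinear when bitnet=True)
    model.model = ml.replace_linear(model.model)
```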
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy

@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
@ -104,6 +105,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
with timer() as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,
references=[args.references.split(";")],
out_path=tmp.name,
max_ar_steps=args.max_ar_steps,