sanity cleanup: moved experimental features under its own thing

mrq 2024-06-30 10:37:33 -05:00
parent b21f74a5c5
commit bc2a6fa756
12 changed files with 80 additions and 106 deletions

View File

@ -23,7 +23,8 @@ I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
## Pre-Trained Model
> [!NOTE] Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.
> [!NOTE]
> Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.
My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
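The weights can also be pulled programmatically; a minimal sketch, assuming `huggingface_hub` is installed (downloading the files manually from the model page works just as well):

```python
# minimal sketch: fetch the pre-trained weights referenced above from the HF Hub
from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download(repo_id="ecker/vall-e")
print(ckpt_dir)  # local path containing the downloaded model files
```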
@ -70,7 +71,7 @@ If you already have a dataset you want, for example, your own large corpus or fo
Two dataset formats are supported:
* the standard way:
- dta is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
- data is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
- it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`
* using an HDF5 dataset:
- you can convert from the standard way with the following command: `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
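To make the "standard way" concrete, here is a small sketch of inspecting one entry; the group/speaker/id below are made up, and the only assumption is that each `.enc`/`.dac` file loads with NumPy as stated above:

```python
# minimal sketch: inspect a single quantized-audio entry from the standard dataset layout
from pathlib import Path
import numpy as np

# hypothetical group / speaker / id; the layout is ./training/data/{group}/{speaker}/{id}.{enc|dac}
path = Path("./training/data/my-group/my-speaker/utterance-0001.enc")
codes = np.load(path, allow_pickle=True)  # allow_pickle is a guess; drop it if the codes are plain arrays
print(type(codes), getattr(codes, "shape", None))
```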
@ -105,7 +106,7 @@ loras:
training: True
```
And thats it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should, as the LoRA weights are initialized to appropriately random values. I found rank and alpha of 128 works fine.
And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should be, as the LoRA weights are initialized to appropriately random values. I found `rank` and `alpha` of 128 works fine.
To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. You *should* be able to have the LoRA weights loaded from a training checkpoint automagically for inferencing, but export them just to be safe.
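Putting the above together, a `loras` entry is a plain dict that `config.py` expands into a `LoRA` dataclass; a minimal sketch (the `name` and the field names other than `training` are assumptions, with `rank`/`alpha` following the "128 works fine" note):

```python
# minimal sketch of a single `loras` entry as parsed from config.yaml
lora_entry = {
    "name": "my-lora",   # hypothetical identifier
    "rank": 128,         # per the note above: rank of 128 works fine
    "alpha": 128,        # per the note above: alpha of 128 works fine
    "training": True,    # as shown in the YAML excerpt
}
# config.py then builds the dataclass with: LoRA(**lora_entry)
```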
@ -125,7 +126,7 @@ Creature comforts like `float16`, `amp`, and multi-GPU training *should* work un
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLm architectures:
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
+ I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
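For a sense of what "slotted in" means, the LLaMA path boils down to wrapping Hugging Face's `LlamaModel` around the repo's own embeddings and classifiers; a rough sketch (dimensions are placeholders, not the foundational model's actual sizes):

```python
# illustrative only: an HF LLaMA backbone used as the attention-based transformer core
from transformers import LlamaConfig, LlamaModel

backbone = LlamaModel(LlamaConfig(
    vocab_size=256,        # placeholder; the repo feeds its own embeddings, so this is largely unused
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
))
```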

View File

@ -17,7 +17,6 @@ def main():
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-nar-levels", type=int, default=7)
parser.add_argument("--max-ar-context", type=int, default=-1)
parser.add_argument("--ar-temp", type=float, default=1.0)
parser.add_argument("--nar-temp", type=float, default=0.01)
@ -50,7 +49,6 @@ def main():
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
max_ar_context=args.max_ar_context,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,

View File

@ -195,6 +195,15 @@ class Dataset:
def max_duration(self):
return self.duration_range[1]
@dataclass()
class ModelExperimentalSettings:
hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends)
# I really need to clean this up
@dataclass()
class Model:
@ -209,21 +218,17 @@ class Model:
experts: int = 1 # for mixtral / retnet-ts
arch_type: str = "llama" # underling LM architecture used
training: bool = True # I really need to attend to this
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "auto" # for llama arch_types: attention used
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
dropout: float = 0.1 # adjustable dropout value
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
loss_factors: dict = field(default_factory=lambda: {})
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: str | None = None # for now it sets things to be HF compatible
kv_heads: int = 0 # MHA or GQA (for supported backends)
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@ -758,16 +763,27 @@ class Config(BaseConfig):
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
"""
if self.models is not None:
self.model = Model(**next(iter(self.models)))
else:
self.model = Model(**self.model)
"""
for model in self.models:
if not isinstance( model, dict ):
continue
if "audio_embedding_sums" not in model:
continue
if not model["experimental"]:
model["experimental"] = {}
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
for model in self.models:
if not isinstance( model.experimental, dict ):
continue
model.experimental = ModelExperimentalSettings(**model.experimental)
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
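To illustrate the backwards-compatibility shim above: an old-style model entry that still carries `audio_embedding_sums` at the top level gets folded into the new `experimental` sub-dict before the dataclasses are constructed. A standalone sketch (the model name is hypothetical):

```python
# minimal sketch of the migration performed in Config: old flat key -> experimental sub-dict
model = {"name": "ar+nar", "audio_embedding_sums": True}  # hypothetical old-style entry

if "audio_embedding_sums" in model:
    if not model.get("experimental"):
        model["experimental"] = {}
    model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")

assert model == {"name": "ar+nar", "experimental": {"audio_embedding_sums": True}}
# from here, Config builds Model(**model) and then ModelExperimentalSettings(**model.experimental)
```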

View File

@ -761,7 +761,8 @@ class Dataset(_Dataset):
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
# append additional prompts in an attempt to artifically increase lengths / offer new data
if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
"""
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
if len(choices) > 0:
@ -790,6 +791,7 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch.concat([ resps, qnt ])
"""
task = "tts"
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
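The commented-out path above naively concatenates code sequences; the in-code note suggests the safer route of decoding to waveforms, inserting silence, and re-encoding. A sketch of that idea, with `decode_fn`/`encode_fn` as stand-ins since the repo's codec helpers aren't shown in this diff:

```python
import torch
from typing import Callable

# sketch only: append two quantized utterances by round-tripping through the waveform domain
def append_with_silence(a: torch.Tensor, b: torch.Tensor,
                        decode_fn: Callable[[torch.Tensor], torch.Tensor],
                        encode_fn: Callable[[torch.Tensor], torch.Tensor],
                        sample_rate: int = 24_000, gap_seconds: float = 0.5) -> torch.Tensor:
    silence = torch.zeros(int(sample_rate * gap_seconds))
    wav = torch.cat([decode_fn(a), silence, decode_fn(b)])
    return encode_fn(wav)  # re-encode so the result is one coherent code sequence
```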

View File

@ -119,7 +119,6 @@ class TTS():
references,
language="en",
max_ar_steps=6 * cfg.dataset.frames_per_second,
max_ar_context=-1,
max_nar_levels=7,
input_prompt_length=0.0,
ar_temp=0.95,
@ -182,7 +181,7 @@ class TTS():
# prom size: 3
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,

View File

@ -19,7 +19,7 @@ def get_model(config, training=True):
training = training,
config = config,
)
elif config.experimental:
elif config.experimental.hf:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=config.text_tokens,
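Since `experimental` is no longer a bare string/None flag but a `ModelExperimentalSettings` object, the HF code path is now gated on its `hf` field explicitly; a small illustration using an abridged copy of the dataclass added in this commit:

```python
from dataclasses import dataclass

@dataclass
class ModelExperimentalSettings:  # abridged; see the full definition in config.py above
    hf: bool = False
    interleave: bool = False

exp = ModelExperimentalSettings()           # defaults: use the repo's own models
assert exp.hf is False
exp = ModelExperimentalSettings(hf=True)    # opts into the experimental HF-compatible model
assert exp.hf is True
```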

View File

@ -98,20 +98,6 @@ class AR_NAR(Base):
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
return 1 # if self.causal else 0
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return True
@property
def audio_embeddings_mode(self) -> bool:
if hasattr(self, "config") and self.config:
return self.config.audio_embeddings_mode
return cfg.model.audio_embeddings_mode
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
@ -144,7 +130,6 @@ class AR_NAR(Base):
max_steps: int = 1000,
max_levels: int = 0,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
@ -305,17 +290,9 @@ class AR_NAR(Base):
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
# get next in sequence
for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
if max_resp_context > 0:
resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
else:
resps_list = self._unsqueeze_list(sequence_list)
resps_list = self._unsqueeze_list(sequence_list)
inputs = self.inputs(
text_list=text_list,

View File

@ -278,7 +278,6 @@ class Metrics(nn.Module):
class Base(nn.Module):
# to-do: clean up this property mess
@property
def causal(self) -> bool:
raise NotImplementedError
@ -319,21 +318,9 @@ class Base(nn.Module):
def causal_size(self) -> int:
raise NotImplementedError
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return False
@property
def audio_embeddings_mode(self) -> str | None:
return None
@property
def version(self) -> int:
return 1
return 2
@property
def capabilities(self) -> list[str]:
@ -403,8 +390,9 @@ class Base(nn.Module):
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * self.n_resp_levels
audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
split_classifiers = self.config.split_classifiers if self.config is not None else True
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@ -432,12 +420,12 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
sums=audio_embedding_sums,
external_mode=self.audio_embeddings_mode,
external_mode=audio_embeddings_mode,
)
self.resps_emb = AudioEmbedding(
l_tokens, d_model,
sums=audio_embedding_sums,
external_mode=self.audio_embeddings_mode,
external_mode=audio_embeddings_mode,
)
# useless since I actually removed using these with the input processing overhaul...
@ -476,7 +464,6 @@ class Base(nn.Module):
if self.config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
@ -497,7 +484,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
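For context on the `kv_heads` guard above: grouped-query attention is requested by giving the HF config fewer key/value heads than attention heads, and the `self.config is not None` check keeps a config-less instantiation on plain multi-head attention. A small sketch with placeholder head counts (`LlamaConfig` used here as one HF config that accepts `num_key_value_heads`):

```python
from transformers import LlamaConfig

n_heads = 16
kv_heads = 4  # placeholder; per the dataclass comment, 0 means "just use MHA"

config = LlamaConfig(
    num_attention_heads=n_heads,
    # mirrors the guard above: GQA only when kv_heads > 0, otherwise fall back to MHA
    num_key_value_heads=kv_heads if kv_heads > 0 else n_heads,
)
print(config.num_key_value_heads)  # 4 -> grouped-query attention
```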
@ -513,7 +500,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
@ -529,9 +516,6 @@ class Base(nn.Module):
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
#if training:
# self.model.training = True
elif self.arch_type == "llama":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
@ -575,9 +559,6 @@ class Base(nn.Module):
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
#if training:
# self.model.training = True
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_resp_tokens,
@ -695,9 +676,6 @@ class Base(nn.Module):
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
#if training:
# self.model.training = True
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
@ -970,7 +948,7 @@ class Base(nn.Module):
task_list.append( input )
elif name == "prom":
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
target.append( torch.full_like(input[..., 0], self.ignore_index) )
# we *CAN* directly map to proms
else:
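On the prom-target change above: with summed audio embeddings, a prompt position has no single token it corresponds to, so it is masked out of the loss; without sums, the prompt's codes can serve as targets directly. A rough sketch of the masking half only (the direct-mapping branch depends on code outside this hunk):

```python
import torch

ignore_index = -100                        # placeholder for self.ignore_index
prom = torch.randint(0, 1024, (75, 8))     # hypothetical prompt codes: [timesteps, RVQ levels]

# summed-embedding case: fill the prompt span with ignore tokens so it contributes no loss
masked_target = torch.full_like(prom[..., 0], ignore_index)
assert masked_target.shape == (75,)
```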

View File

@ -251,7 +251,7 @@ def example_usage():
kwargs = {}
model = Model(**kwargs).to(device)
steps = 100
steps = 100 if cfg.model.experimental.interleave else 300
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@ -324,7 +324,7 @@ def example_usage():
engine.eval()
batch_size = len(text_list)
resp_list = None
if cfg.model.interleave:
if cfg.model.experimental.interleave:
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
@ -382,7 +382,7 @@ def example_usage():
stats = {"step": i}
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
else:

View File

@ -27,6 +27,12 @@ class NAR(Base):
return self.config.capabilities
return cfg.model.capabilities
@property
def quant_level_range(self) -> list[int]:
if hasattr(self, "config") and self.config.rvq_level_range:
return self.config.rvq_level_range
return [ 0 if self.causal else 1, self.n_resp_levels ]
@property
def causal(self):
return "len" in self.capabilities
@ -64,6 +70,12 @@ class NAR(Base):
if hasattr(self, "config") and self.config:
return self.config.tasks
return cfg.model.tasks
@property
def p_rvq_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.p_rvq_levels
return cfg.model.p_rvq_levels
@property
def n_langs(self) -> int:
@ -84,14 +96,6 @@ class NAR(Base):
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
return 1 # if self.causal else 0
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return True
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
@ -155,9 +159,12 @@ class NAR(Base):
task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch
quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
quant_level_range = self.quant_level_range
if cfg.experimental:
if self.p_rvq_levels == "equal":
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if self.p_rvq_levels == "auto":
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
@ -167,18 +174,16 @@ class NAR(Base):
index = i
return int(index)
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i, prom in enumerate(proms_list):
if quant_levels[i] + 1 > prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
for i, resp in enumerate(resps_list):
if quant_levels[i] + 1 > resp.shape[-1]:
quant_levels[i] = resp.shape[-1] - 1
for i in range(batch_size):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
if quant_levels[i] >= proms_list[i].shape[-1]:
quant_levels[i] = proms_list[i].shape[-1] - 1
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
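To spell out the two `p_rvq_levels` modes used above: "equal" draws each RVQ level uniformly, while "auto" biases toward lower levels via `generate()`. The exact weighting of `generate()` falls outside this hunk, so the "auto" sketch below uses an assumed exponential falloff purely for illustration:

```python
import random

def sample_equal(lo: int, hi: int) -> int:
    # "equal": every level in [lo, hi) is equally likely
    return random.randint(lo, hi - 1)

def sample_auto(lo: int, hi: int) -> int:
    # "auto": make higher levels less likely; this halving-per-level weighting is an assumption,
    # not necessarily the repo's actual generate() implementation
    weights = [2.0 ** -(i - lo) for i in range(lo, hi)]
    return random.choices(range(lo, hi), weights=weights, k=1)[0]
```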

View File

@ -31,8 +31,8 @@ def train_feeder(engine, batch):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
if engine.hyper_config.experimental:
if cfg.model.interleave:
if engine.hyper_config.experimental.hf:
if engine.hyper_config.experimental.interleave:
quant_levels = 0
resps_list = [ resp for resp in batch["resps"] ]
else:
@ -129,8 +129,8 @@ def run_eval(engines, eval_name, dl):
for name in engines:
engine = engines[name]
if engine.hyper_config.experimental:
if cfg.model.interleave:
if engine.hyper_config.experimental.hf:
if engine.hyper_config.experimental.interleave:
input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
prom_list=batch["proms"],

View File

@ -78,7 +78,6 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@ -211,7 +210,6 @@ with ui:
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")