sanity cleanup: moved experimental features under its own thing

mrq 2024-06-30 10:37:33 -05:00
parent b21f74a5c5
commit bc2a6fa756
12 changed files with 80 additions and 106 deletions

View File

@@ -23,7 +23,8 @@ I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 ## Pre-Trained Model
-> [!NOTE] Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.
+> [!NOTE]
+> Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.
 My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
@@ -70,7 +71,7 @@ If you already have a dataset you want, for example, your own large corpus or fo
 Two dataset formats are supported:
 * the standard way:
-  - dta is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
+  - data is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
   - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`
 * using an HDF5 dataset:
   - you can convert from the standard way with the following command: `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
@@ -105,7 +106,7 @@ loras:
   training: True
 ```
-And thats it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should, as the LoRA weights are initialized to appropriately random values. I found rank and alpha of 128 works fine.
+And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should, as the LoRA weights are initialized to appropriately random values. I found `rank` and `alpha` of 128 works fine.
 To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. You *should* be able to have the LoRA weights loaded from a training checkpoint automagically for inferencing, but export them just to be safe.
@@ -125,7 +126,7 @@ Creature comforts like `float16`, `amp`, and multi-GPU training *should* work un
 #### Backend Architectures
-As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLm architectures:
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
   + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).

View File

@@ -17,7 +17,6 @@ def main():
 parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
 parser.add_argument("--max-nar-levels", type=int, default=7)
-parser.add_argument("--max-ar-context", type=int, default=-1)
 parser.add_argument("--ar-temp", type=float, default=1.0)
 parser.add_argument("--nar-temp", type=float, default=0.01)
@@ -50,7 +49,6 @@ def main():
 out_path=args.out_path,
 input_prompt_length=args.input_prompt_length,
 max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-max_ar_context=args.max_ar_context,
 ar_temp=args.ar_temp, nar_temp=args.nar_temp,
 min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
 top_p=args.top_p, top_k=args.top_k,

View File

@@ -195,6 +195,15 @@ class Dataset:
 def max_duration(self):
 return self.duration_range[1]
+@dataclass()
+class ModelExperimentalSettings:
+hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
+interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
+split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
+audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
+audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
+kv_heads: int = 0 # MHA or GQA (for supported backends)
 # I really need to clean this up
 @dataclass()
 class Model:
@@ -209,22 +218,18 @@ class Model:
 experts: int = 1 # for mixtral / retnet-ts
 arch_type: str = "llama" # underling LM architecture used
 training: bool = True # I really need to attend to this
-interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
 frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 attention: str = "auto" # for llama arch_types: attention used
-audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
-split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
 dropout: float = 0.1 # adjustable dropout value
 #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 loss_factors: dict = field(default_factory=lambda: {})
 capabilities: list = field(default_factory=lambda: ["ar", "nar"])
-experimental: str | None = None # for now it sets things to be HF compatible
-kv_heads: int = 0 # MHA or GQA (for supported backends)
-audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
+experimental: dict | ModelExperimentalSettings | None = None # experimental settings
 def get(self, name=None):
 return [ self ] if not name or self.name == name else []
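
What the reorganization looks like in practice: the flags that used to sit directly on `Model` (`interleave`, `split_classifiers`, `audio_embedding_sums`, `audio_embeddings_mode`, `kv_heads`, plus the old `experimental` HF toggle, now `hf`) move onto the new `ModelExperimentalSettings` dataclass, and `Model.experimental` holds either that dataclass or a raw dict loaded from YAML. A minimal, self-contained sketch using toy stand-ins for the real `vall_e.config` classes (field values are illustrative):

```python
from dataclasses import dataclass

# toy stand-ins mirroring the diff above, not the library's own classes
@dataclass()
class ModelExperimentalSettings:
    hf: bool = False
    interleave: bool = False
    split_classifiers: bool = False
    audio_embedding_sums: bool = False
    audio_embeddings_mode: str | None = None
    kv_heads: int = 0

@dataclass()
class Model:
    name: str = "ar+nar"  # illustrative
    arch_type: str = "llama"
    experimental: dict | ModelExperimentalSettings | None = None

# what a YAML-loaded model entry might look like after this commit (values illustrative)
raw = {
    "name": "ar+nar",
    "experimental": { "split_classifiers": True, "audio_embedding_sums": True, "kv_heads": 0 },
}

model = Model(**raw)
# the same promotion step Config now performs after building the Model dataclasses
if isinstance(model.experimental, dict):
    model.experimental = ModelExperimentalSettings(**model.experimental)

print(model.experimental.split_classifiers)  # True
```
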
@@ -758,16 +763,27 @@ class Config(BaseConfig):
 self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
-"""
-if self.models is not None:
-self.model = Model(**next(iter(self.models)))
-else:
-self.model = Model(**self.model)
-"""
+for model in self.models:
+if not isinstance( model, dict ):
+continue
+if "audio_embedding_sums" not in model:
+continue
+if not model["experimental"]:
+model["experimental"] = {}
+model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
 self.models = [ Model(**model) for model in self.models ]
 self.loras = [ LoRA(**lora) for lora in self.loras ]
+for model in self.models:
+if not isinstance( model.experimental, dict ):
+continue
+model.experimental = ModelExperimentalSettings(**model.experimental)
 self.hyperparameters = Hyperparameters(**self.hyperparameters)
 self.evaluation = Evaluation(**self.evaluation)
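
The second hunk is a backwards-compatibility shim: model dicts that still carry the legacy top-level `audio_embedding_sums` key get it relocated under `experimental` before the dataclasses are constructed, and any remaining `experimental` dict is promoted to `ModelExperimentalSettings` afterwards. A hedged sketch of the relocation step on a plain dict (the helper name is made up for illustration):

```python
def migrate_model_dict(model: dict) -> dict:
    """Illustrative helper: move the legacy top-level `audio_embedding_sums` key under `experimental`."""
    if "audio_embedding_sums" in model:
        if not model.get("experimental"):
            model["experimental"] = {}
        model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
    return model

old_style = { "name": "ar+nar", "audio_embedding_sums": True }
print(migrate_model_dict(old_style))
# {'name': 'ar+nar', 'experimental': {'audio_embedding_sums': True}}
```
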

View File

@@ -761,7 +761,8 @@ class Dataset(_Dataset):
 lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 # append additional prompts in an attempt to artifically increase lengths / offer new data
-if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
+"""
+if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
 choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
 if len(choices) > 0:
@@ -790,6 +791,7 @@ class Dataset(_Dataset):
 # might be better to decode => concat waveforms with silence in between => reencode
 # as you technically can't just append encodec sequences together like this without issues
 resps = torch.concat([ resps, qnt ])
+"""
 task = "tts"
 trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)

View File

@@ -119,7 +119,6 @@ class TTS():
 references,
 language="en",
 max_ar_steps=6 * cfg.dataset.frames_per_second,
-max_ar_context=-1,
 max_nar_levels=7,
 input_prompt_length=0.0,
 ar_temp=0.95,
@@ -182,7 +181,7 @@ class TTS():
 # prom size: 3
 if model_ar is not None:
 resps_list = model_ar(
-text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
+text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
 sampling_temperature=ar_temp,
 sampling_min_temperature=min_ar_temp,
 sampling_top_p=top_p, sampling_top_k=top_k,

View File

@@ -19,7 +19,7 @@ def get_model(config, training=True):
 training = training,
 config = config,
 )
-elif config.experimental:
+elif config.experimental.hf:
 from .experimental import Model as Experimental
 model = Experimental(
 n_text_tokens=config.text_tokens,

View File

@@ -98,20 +98,6 @@ class AR_NAR(Base):
 # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
 return 1 # if self.causal else 0
-@property
-def interleave(self) -> bool:
-return False
-@property
-def monolithic(self) -> bool:
-return True
-@property
-def audio_embeddings_mode(self) -> bool:
-if hasattr(self, "config") and self.config:
-return self.config.audio_embeddings_mode
-return cfg.model.audio_embeddings_mode
 @property
 def version(self) -> int:
 if hasattr(self, "config") and self.config:
@@ -144,7 +130,6 @@ class AR_NAR(Base):
 max_steps: int = 1000,
 max_levels: int = 0,
-max_resp_context: int = -1,
 sampling_temperature: float = 1.0,
 sampling_min_temperature: float = -1.0,
@@ -305,17 +290,9 @@ class AR_NAR(Base):
 scores = [ 1.0 ] * sampling_beam_width
-if self.interleave:
-max_steps *= self.n_prom_levels
 # get next in sequence
 for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
-# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
-# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
-if max_resp_context > 0:
-resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
-else:
-resps_list = self._unsqueeze_list(sequence_list)
+resps_list = self._unsqueeze_list(sequence_list)
 inputs = self.inputs(
 text_list=text_list,

View File

@@ -278,7 +278,6 @@ class Metrics(nn.Module):
 class Base(nn.Module):
 # to-do: clean up this property mess
 @property
 def causal(self) -> bool:
 raise NotImplementedError
@@ -319,21 +318,9 @@ class Base(nn.Module):
 def causal_size(self) -> int:
 raise NotImplementedError
-@property
-def interleave(self) -> bool:
-return False
-@property
-def monolithic(self) -> bool:
-return False
-@property
-def audio_embeddings_mode(self) -> str | None:
-return None
 @property
 def version(self) -> int:
-return 1
+return 2
 @property
 def capabilities(self) -> list[str]:
@@ -403,8 +390,9 @@ class Base(nn.Module):
 n_resp_tokens = n_audio_tokens
 l_tokens = [n_resp_tokens] * self.n_resp_levels
-audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
-split_classifiers = self.config.split_classifiers if self.config is not None else True
+audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
+split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
+audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""
 self.text_emb = Embedding(n_text_tokens, d_model)
 self.langs_emb = None
@@ -432,12 +420,12 @@ class Base(nn.Module):
 self.proms_emb = AudioEmbedding(
 [n_prom_tokens] * self.n_prom_levels, d_model,
 sums=audio_embedding_sums,
-external_mode=self.audio_embeddings_mode,
+external_mode=audio_embeddings_mode,
 )
 self.resps_emb = AudioEmbedding(
 l_tokens, d_model,
 sums=audio_embedding_sums,
-external_mode=self.audio_embeddings_mode,
+external_mode=audio_embeddings_mode,
 )
 # useless since I actually removed using these with the input processing overhaul...
@@ -476,7 +464,6 @@ class Base(nn.Module):
 if self.config.attention not in AVAILABLE_ATTENTIONS:
 raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 if self.arch_type == "transformer":
 self.sin_emb = SinusoidalEmbedding(d_model)
 self.blocks = nn.ModuleList([TransformerBlock(
@@ -497,7 +484,7 @@ class Base(nn.Module):
 num_hidden_layers=n_layers,
 num_attention_heads=n_heads,
 attention_dropout=p_dropout if training else 0.0,
-num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
 hidden_act="gelu",
 is_encoder_decoder=False,
 is_decoder=True,
@@ -513,7 +500,7 @@ class Base(nn.Module):
 num_hidden_layers=n_layers,
 num_attention_heads=n_heads,
 attention_dropout=p_dropout if training else 0.0,
-num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
 sliding_window=75 * 12, # 12 second context window
 output_router_logits=training,
 hidden_act="gelu",
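
Both HF backend hunks now read the grouped-query attention setting from its new home, `config.experimental.kv_heads`, falling back to `n_heads` (plain multi-head attention) when it is 0 or no config is attached. A minimal sketch of what that knob maps to on an HF-style config (illustrative sizes; assumes `transformers` is installed):

```python
from transformers import LlamaConfig, LlamaModel

n_heads = 16
kv_heads = 4  # 0 would mean "fall back to n_heads", i.e. plain multi-head attention

config = LlamaConfig(
    vocab_size=1024,
    hidden_size=1024,
    num_hidden_layers=2,
    num_attention_heads=n_heads,
    # GQA: several query heads share one KV head; kv_heads == n_heads is MHA, kv_heads == 1 is MQA
    num_key_value_heads=kv_heads if kv_heads > 0 else n_heads,
)
model = LlamaModel(config)
print(model.config.num_key_value_heads)  # 4
```
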
@@ -529,9 +516,6 @@ class Base(nn.Module):
 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 use_reentrant=False
 ))
-#if training:
-# self.model.training = True
 elif self.arch_type == "llama":
 if n_experts <= 1:
 self.model = LlamaModel(LlamaConfig(
@@ -575,9 +559,6 @@ class Base(nn.Module):
 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 use_reentrant=False
 ))
-#if training:
-# self.model.training = True
 elif self.arch_type == "retnet":
 kwargs = dict(
 vocab_size=n_resp_tokens,
@@ -695,9 +676,6 @@ class Base(nn.Module):
 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 use_reentrant=False
 ))
-#if training:
-# self.model.training = True
 else:
 raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
@@ -970,7 +948,7 @@ class Base(nn.Module):
 task_list.append( input )
 elif name == "prom":
 # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
-if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
 target.append( torch.full_like(input[..., 0], self.ignore_index) )
 # we *CAN* directly map to proms
 else:
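
The last hunk in this file keeps prompt positions out of the loss when `audio_embedding_sums` is enabled (or on older model versions), by writing `ignore_index` into the target. A small sketch of that masking pattern in plain PyTorch, separate from the repo's own target-building code:

```python
import torch
import torch.nn.functional as F

ignore_index = -100                       # targets with this value are skipped by cross_entropy
logits = torch.randn(6, 32)               # 6 sequence positions, 32-token vocab
targets = torch.randint(0, 32, (6,))

# pretend the first three positions are prompt tokens that shouldn't be scored
prom_positions = torch.tensor([True, True, True, False, False, False])
targets = targets.masked_fill(prom_positions, ignore_index)

loss = F.cross_entropy(logits, targets, ignore_index=ignore_index)
print(loss)
```
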

View File

@@ -251,7 +251,7 @@ def example_usage():
 kwargs = {}
 model = Model(**kwargs).to(device)
-steps = 100
+steps = 100 if cfg.model.experimental.interleave else 300
 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -324,7 +324,7 @@ def example_usage():
 engine.eval()
 batch_size = len(text_list)
 resp_list = None
-if cfg.model.interleave:
+if cfg.model.experimental.interleave:
 input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
 output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
@@ -382,7 +382,7 @@ def example_usage():
 stats = {"step": i}
 batch_size = len(text_list)
-quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
+quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
 if quant_levels is not None:
 resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
 else:

View File

@@ -27,6 +27,12 @@ class NAR(Base):
 return self.config.capabilities
 return cfg.model.capabilities
+@property
+def quant_level_range(self) -> list[int]:
+if hasattr(self, "config") and self.config.rvq_level_range:
+return self.config.rvq_level_range
+return [ 0 if self.causal else 1, self.n_resp_levels ]
 @property
 def causal(self):
 return "len" in self.capabilities
@@ -65,6 +71,12 @@ class NAR(Base):
 return self.config.tasks
 return cfg.model.tasks
+@property
+def p_rvq_levels(self) -> int:
+if hasattr(self, "config") and self.config:
+return self.config.p_rvq_levels
+return cfg.model.p_rvq_levels
 @property
 def n_langs(self) -> int:
 if hasattr(self, "config") and self.config:
@@ -84,14 +96,6 @@ class NAR(Base):
 # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
 return 1 # if self.causal else 0
-@property
-def interleave(self) -> bool:
-return False
-@property
-def monolithic(self) -> bool:
-return True
 @property
 def version(self) -> int:
 if hasattr(self, "config") and self.config:
@@ -155,9 +159,12 @@ class NAR(Base):
 task_list = [ sample_task() for _ in range(batch_size) ]
 # determines which RVQ level to target per batch
-quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
-if cfg.experimental:
+quant_level_range = self.quant_level_range
+if self.p_rvq_levels == "equal":
+# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+else: # if self.p_rvq_levels == "auto":
 # makes higher levels less likely
 def generate( lo=0, hi=8 ):
 index = lo
@@ -167,18 +174,16 @@ class NAR(Base):
 index = i
 return int(index)
-quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
-else:
-# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
 # clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
-for i, prom in enumerate(proms_list):
-if quant_levels[i] + 1 > prom.shape[-1]:
-quant_levels[i] = prom.shape[-1] - 1
-for i, resp in enumerate(resps_list):
-if quant_levels[i] + 1 > resp.shape[-1]:
-quant_levels[i] = resp.shape[-1] - 1
+for i in range(batch_size):
+# cap quant_level if it exceeds its corresponding resp/prom
+if quant_levels[i] >= resps_list[i].shape[-1]:
+quant_levels[i] = resps_list[i].shape[-1] - 1
+if quant_levels[i] >= proms_list[i].shape[-1]:
+quant_levels[i] = proms_list[i].shape[-1] - 1
 resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
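
The level-selection rework above branches on the new `p_rvq_levels` setting: `"equal"` draws each RVQ level uniformly, the `"auto"` path keeps the low-level-biased `generate()` weighting (only partially visible in this hunk), and both are then clamped so a sampled level never exceeds what a given resp/prom actually stores. A hedged sketch of the idea; `pick_level_auto` uses an assumed `1 / 2**i` weighting, not necessarily the repo's exact curve:

```python
import random
import torch

def pick_level_equal(lo: int, hi: int) -> int:
    # "equal": every RVQ level in [lo, hi) is equally likely
    return random.randint(lo, hi - 1)

def pick_level_auto(lo: int, hi: int) -> int:
    # "auto": bias toward lower levels; the 1 / 2**i weight per level is an assumption
    levels = list(range(lo, hi))
    weights = [1.0 / (2 ** i) for i in levels]
    return random.choices(levels, weights=weights, k=1)[0]

def clamp_levels(quant_levels, resps_list, proms_list):
    # cap a sampled level when the stored resp/prom kept fewer RVQ levels (e.g. 8-of-9 DAC saves)
    for i in range(len(quant_levels)):
        quant_levels[i] = min(quant_levels[i],
                              resps_list[i].shape[-1] - 1,
                              proms_list[i].shape[-1] - 1)
    return quant_levels

quant_levels = [ pick_level_auto(1, 8) for _ in range(2) ]
resps = [ torch.zeros(75, 8), torch.zeros(75, 4) ]   # second sample only kept 4 RVQ levels
proms = [ torch.zeros(150, 8), torch.zeros(150, 8) ]
print(clamp_levels(quant_levels, resps, proms))
```
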

View File

@@ -31,8 +31,8 @@ def train_feeder(engine, batch):
 batch_size = len(batch["text"])
 engine.current_batch_size = batch_size
-if engine.hyper_config.experimental:
-if cfg.model.interleave:
+if engine.hyper_config.experimental.hf:
+if engine.hyper_config.experimental.interleave:
 quant_levels = 0
 resps_list = [ resp for resp in batch["resps"] ]
 else:
@@ -129,8 +129,8 @@ def run_eval(engines, eval_name, dl):
 for name in engines:
 engine = engines[name]
-if engine.hyper_config.experimental:
-if cfg.model.interleave:
+if engine.hyper_config.experimental.hf:
+if engine.hyper_config.experimental.interleave:
 input_ids, attention_mask = fold_inputs(
 text_list=batch["text"],
 prom_list=batch["proms"],

View File

@@ -78,7 +78,6 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 parser.add_argument("--language", type=str, default="en")
 parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
 parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
-parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
 parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
 parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
 parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@@ -211,7 +210,6 @@ with ui:
 layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
-layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
 with gr.Row():
 layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")