sanity cleanup: moved experimental features under its own thing
commit bc2a6fa756 (parent b21f74a5c5)
@@ -23,7 +23,8 @@ I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 ## Pre-Trained Model

-> [!NOTE] Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.
+> [!NOTE]
+> Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.

 My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).

@@ -70,7 +71,7 @@ If you already have a dataset you want, for example, your own large corpus or fo

 Two dataset formats are supported:
 * the standard way:
-  - dta is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
+  - data is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
   - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`
 * using an HDF5 dataset:
   - you can convert from the standard way with the following command: `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
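For illustration, a minimal sketch of walking the standard layout and inspecting one entry; it assumes the `.enc`/`.dac` files are ordinary NumPy arrays readable with `np.load`, which may not match the repo's actual serialization.

```
from pathlib import Path
import numpy as np

# Hypothetical walk over ./training/data/{group}/{speaker}/{id}.enc
root = Path("./training/data")
for path in sorted(root.glob("*/*/*.enc")):
    group, speaker, utterance_id = path.parts[-3], path.parts[-2], path.stem
    codes = np.load(path, allow_pickle=True)  # assumption: stored as a plain NumPy array of codec codes
    print(group, speaker, utterance_id, getattr(codes, "shape", None))
```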
@@ -105,7 +106,7 @@ loras:
   training: True
 ```

-And thats it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should, as the LoRA weights are initialized to appropriately random values. I found rank and alpha of 128 works fine.
+And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should, as the LoRA weights are initialized to appropriately random values. I found `rank` and `alpha` of 128 works fine.

 To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. You *should* be able to have the LoRA weights loaded from a training checkpoint automagically for inferencing, but export them just to be safe.

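As a rough sketch of what one such entry carries once parsed from the YAML: only `training` appears verbatim above; `name`, `rank`, and `alpha` are assumed field names inferred from the prose recommendation.

```
# Hypothetical parsed form of one `loras:` entry; field names other than
# `training` are assumptions, with rank/alpha of 128 per the prose above.
lora_entry = {
    "name": "my-lora",
    "rank": 128,
    "alpha": 128,
    "training": True,
}
```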
@@ -125,7 +126,7 @@ Creature comforts like `float16`, `amp`, and multi-GPU training *should* work un

 #### Backend Architectures

-As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLm architectures:
+As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:

 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
   + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
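A simplified sketch of how that slotting works: the `arch_type` string on the `Model` config (see the dataclass hunks below) decides which backend gets built. This is illustration only, not the actual construction code.

```
# Simplified dispatch sketch; the real code builds the corresponding
# model objects (e.g. transformers' LlamaModel / LlamaConfig for "llama").
arch_type = "llama"   # the dataclass default shown in the config hunks below

if arch_type == "llama":
    backend = "HF transformers LLaMA"
elif arch_type == "retnet":
    backend = "RetNet decoder"
elif arch_type == "transformer":
    backend = "built-in transformer blocks with sinusoidal embeddings"
else:
    raise RuntimeError(f"Unknown arch specified: {arch_type}")
```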
@@ -17,7 +17,6 @@ def main():
     parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
     parser.add_argument("--max-nar-levels", type=int, default=7)
-    parser.add_argument("--max-ar-context", type=int, default=-1)

     parser.add_argument("--ar-temp", type=float, default=1.0)
     parser.add_argument("--nar-temp", type=float, default=0.01)
@@ -50,7 +49,6 @@ def main():
         out_path=args.out_path,
         input_prompt_length=args.input_prompt_length,
         max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-        max_ar_context=args.max_ar_context,
         ar_temp=args.ar_temp, nar_temp=args.nar_temp,
         min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
         top_p=args.top_p, top_k=args.top_k,
@@ -195,6 +195,15 @@ class Dataset:
     def max_duration(self):
         return self.duration_range[1]

+@dataclass()
+class ModelExperimentalSettings:
+    hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
+    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
+    split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
+    audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
+    audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
+    kv_heads: int = 0 # MHA or GQA (for supported backends)

 # I really need to clean this up
 @dataclass()
 class Model:
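A quick sketch of how the new container is meant to be used, given the dataclass above: the `experimental` sub-dict from a parsed config becomes one instance, with anything unspecified falling back to the defaults (the example dict keys are hypothetical user settings).

```
# Sketch: build the settings object the same way Config does further down.
experimental = {"split_classifiers": True, "kv_heads": 4}   # hypothetical user settings
settings = ModelExperimentalSettings(**experimental)
assert settings.split_classifiers and settings.kv_heads == 4
assert settings.hf is False and settings.interleave is False  # untouched defaults
```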
@@ -209,21 +218,17 @@ class Model:
     experts: int = 1 # for mixtral / retnet-ts
     arch_type: str = "llama" # underling LM architecture used
     training: bool = True # I really need to attend to this
-    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     attention: str = "auto" # for llama arch_types: attention used
-    audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
-    split_classifiers: bool = False # experimental, but each RVQ level gets its own classifier / output proj / LM head
     dropout: float = 0.1 # adjustable dropout value
     #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
     loss_factors: dict = field(default_factory=lambda: {})
     capabilities: list = field(default_factory=lambda: ["ar", "nar"])
-    experimental: str | None = None # for now it sets things to be HF compatible
-    kv_heads: int = 0 # MHA or GQA (for supported backends)
-    audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings

     p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
     rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range

+    experimental: dict | ModelExperimentalSettings | None = None # experimental settings

     def get(self, name=None):
         return [ self ] if not name or self.name == name else []
@@ -758,16 +763,27 @@ class Config(BaseConfig):
         self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
         self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]

         """
         if self.models is not None:
             self.model = Model(**next(iter(self.models)))
         else:
             self.model = Model(**self.model)
         """
+        for model in self.models:
+            if not isinstance( model, dict ):
+                continue
+
+            if "audio_embedding_sums" not in model:
+                continue
+
+            if not model["experimental"]:
+                model["experimental"] = {}
+
+            model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
+
         self.models = [ Model(**model) for model in self.models ]
         self.loras = [ LoRA(**lora) for lora in self.loras ]

+        for model in self.models:
+            if not isinstance( model.experimental, dict ):
+                continue
+            model.experimental = ModelExperimentalSettings(**model.experimental)

         self.hyperparameters = Hyperparameters(**self.hyperparameters)

         self.evaluation = Evaluation(**self.evaluation)
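In practice the backwards-compatibility shim above means an older YAML that still sets `audio_embedding_sums` at the model's top level keeps working. A rough sketch of the transformation it performs before `Model(**model)` is called (dict keys other than the two shown in the hunk are hypothetical):

```
# Hypothetical older-style model entry, as parsed from YAML.
model = {"name": "ar+nar", "audio_embedding_sums": True, "experimental": None}

# What the shim effectively does:
if "audio_embedding_sums" in model:
    if not model["experimental"]:
        model["experimental"] = {}
    model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")

# model is now {"name": "ar+nar", "experimental": {"audio_embedding_sums": True}}
```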
@@ -761,7 +761,8 @@ class Dataset(_Dataset):
         lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)

         # append additional prompts in an attempt to artifically increase lengths / offer new data
-        if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
+        """
+        if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
             choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]

             if len(choices) > 0:
@@ -790,6 +791,7 @@ class Dataset(_Dataset):
             # might be better to decode => concat waveforms with silence in between => reencode
             # as you technically can't just append encodec sequences together like this without issues
             resps = torch.concat([ resps, qnt ])
+        """

         task = "tts"
         trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
@@ -119,7 +119,6 @@ class TTS():
         references,
         language="en",
         max_ar_steps=6 * cfg.dataset.frames_per_second,
-        max_ar_context=-1,
         max_nar_levels=7,
         input_prompt_length=0.0,
         ar_temp=0.95,
@@ -182,7 +181,7 @@ class TTS():
             # prom size: 3
             if model_ar is not None:
                 resps_list = model_ar(
-                    text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
+                    text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
                     sampling_temperature=ar_temp,
                     sampling_min_temperature=min_ar_temp,
                     sampling_top_p=top_p, sampling_top_k=top_k,
@@ -19,7 +19,7 @@ def get_model(config, training=True):
             training = training,
             config = config,
         )
-    elif config.experimental:
+    elif config.experimental.hf:
        from .experimental import Model as Experimental
        model = Experimental(
            n_text_tokens=config.text_tokens,
@@ -98,20 +98,6 @@ class AR_NAR(Base):
         # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
         return 1 # if self.causal else 0

-    @property
-    def interleave(self) -> bool:
-        return False
-
-    @property
-    def monolithic(self) -> bool:
-        return True
-
-    @property
-    def audio_embeddings_mode(self) -> bool:
-        if hasattr(self, "config") and self.config:
-            return self.config.audio_embeddings_mode
-        return cfg.model.audio_embeddings_mode
-
     @property
     def version(self) -> int:
         if hasattr(self, "config") and self.config:
@@ -144,7 +130,6 @@ class AR_NAR(Base):

         max_steps: int = 1000,
         max_levels: int = 0,
-        max_resp_context: int = -1,

         sampling_temperature: float = 1.0,
         sampling_min_temperature: float = -1.0,
@@ -305,17 +290,9 @@ class AR_NAR(Base):

         scores = [ 1.0 ] * sampling_beam_width

-        if self.interleave:
-            max_steps *= self.n_prom_levels
-
         # get next in sequence
         for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
-            # experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
-            # UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
-            if max_resp_context > 0:
-                resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
-            else:
-                resps_list = self._unsqueeze_list(sequence_list)
+            resps_list = self._unsqueeze_list(sequence_list)

             inputs = self.inputs(
                 text_list=text_list,
@@ -278,7 +278,6 @@ class Metrics(nn.Module):

 class Base(nn.Module):
-    # to-do: clean up this property mess

     @property
     def causal(self) -> bool:
         raise NotImplementedError
@@ -319,21 +318,9 @@ class Base(nn.Module):
     def causal_size(self) -> int:
         raise NotImplementedError

-    @property
-    def interleave(self) -> bool:
-        return False
-
-    @property
-    def monolithic(self) -> bool:
-        return False
-
-    @property
-    def audio_embeddings_mode(self) -> str | None:
-        return None
-
     @property
     def version(self) -> int:
-        return 1
+        return 2

     @property
     def capabilities(self) -> list[str]:
@@ -403,8 +390,9 @@ class Base(nn.Module):
             n_resp_tokens = n_audio_tokens
             l_tokens = [n_resp_tokens] * self.n_resp_levels

-        audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
-        split_classifiers = self.config.split_classifiers if self.config is not None else True
+        audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
+        split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
+        audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""

         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -432,12 +420,12 @@ class Base(nn.Module):
         self.proms_emb = AudioEmbedding(
             [n_prom_tokens] * self.n_prom_levels, d_model,
             sums=audio_embedding_sums,
-            external_mode=self.audio_embeddings_mode,
+            external_mode=audio_embeddings_mode,
         )
         self.resps_emb = AudioEmbedding(
             l_tokens, d_model,
             sums=audio_embedding_sums,
-            external_mode=self.audio_embeddings_mode,
+            external_mode=audio_embeddings_mode,
         )

         # useless since I actually removed using these with the input processing overhaul...
@@ -476,7 +464,6 @@ class Base(nn.Module):
         if self.config.attention not in AVAILABLE_ATTENTIONS:
             raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

-
         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)
             self.blocks = nn.ModuleList([TransformerBlock(
@@ -497,7 +484,7 @@ class Base(nn.Module):
                 num_hidden_layers=n_layers,
                 num_attention_heads=n_heads,
                 attention_dropout=p_dropout if training else 0.0,
-                num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+                num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
@@ -513,7 +500,7 @@ class Base(nn.Module):
                 num_hidden_layers=n_layers,
                 num_attention_heads=n_heads,
                 attention_dropout=p_dropout if training else 0.0,
-                num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
+                num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
                 sliding_window=75 * 12, # 12 second context window
                 output_router_logits=training,
                 hidden_act="gelu",
@@ -529,9 +516,6 @@ class Base(nn.Module):
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
-
-            #if training:
-            #   self.model.training = True
         elif self.arch_type == "llama":
             if n_experts <= 1:
                 self.model = LlamaModel(LlamaConfig(
@@ -575,9 +559,6 @@ class Base(nn.Module):
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
-
-            #if training:
-            #   self.model.training = True
         elif self.arch_type == "retnet":
             kwargs = dict(
                 vocab_size=n_resp_tokens,
@@ -695,9 +676,6 @@ class Base(nn.Module):
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
-
-            #if training:
-            #   self.model.training = True
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

@@ -970,7 +948,7 @@ class Base(nn.Module):
                 task_list.append( input )
             elif name == "prom":
                 # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
-                if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+                if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
                     target.append( torch.full_like(input[..., 0], self.ignore_index) )
                 # we *CAN* directly map to proms
                 else:
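Across these `Base` changes the pattern is the same: each toggle is now read off `config.experimental` with a conservative fallback when no config is attached. A rough illustration (the helper below is purely expository; the actual code inlines each ternary):

```
# Illustration only: how the experimental toggles are resolved in the base model.
def resolve(config, name, default):
    return getattr(config.experimental, name) if config is not None else default

# fallbacks when no config is attached:
#   audio_embedding_sums -> False, split_classifiers -> False,
#   audio_embeddings_mode -> "", kv_heads -> 0 (n_heads is used instead)
```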
@@ -251,7 +251,7 @@ def example_usage():

     kwargs = {}
     model = Model(**kwargs).to(device)
-    steps = 100
+    steps = 100 if cfg.model.experimental.interleave else 300

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -324,7 +324,7 @@ def example_usage():
         engine.eval()
         batch_size = len(text_list)
         resp_list = None
-        if cfg.model.interleave:
+        if cfg.model.experimental.interleave:
             input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
             output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)

@@ -382,7 +382,7 @@ def example_usage():
         stats = {"step": i}

         batch_size = len(text_list)
-        quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
+        quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
         if quant_levels is not None:
             resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
         else:
@@ -27,6 +27,12 @@ class NAR(Base):
             return self.config.capabilities
         return cfg.model.capabilities

+    @property
+    def quant_level_range(self) -> list[int]:
+        if hasattr(self, "config") and self.config.rvq_level_range:
+            return self.config.rvq_level_range
+        return [ 0 if self.causal else 1, self.n_resp_levels ]
+
     @property
     def causal(self):
         return "len" in self.capabilities
@@ -64,6 +70,12 @@ class NAR(Base):
         if hasattr(self, "config") and self.config:
             return self.config.tasks
         return cfg.model.tasks

+    @property
+    def p_rvq_levels(self) -> int:
+        if hasattr(self, "config") and self.config:
+            return self.config.p_rvq_levels
+        return cfg.model.p_rvq_levels
+
     @property
     def n_langs(self) -> int:
@@ -84,14 +96,6 @@ class NAR(Base):
         # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
         return 1 # if self.causal else 0

-    @property
-    def interleave(self) -> bool:
-        return False
-
-    @property
-    def monolithic(self) -> bool:
-        return True
-
     @property
     def version(self) -> int:
         if hasattr(self, "config") and self.config:
@@ -155,9 +159,12 @@ class NAR(Base):
         task_list = [ sample_task() for _ in range(batch_size) ]

         # determines which RVQ level to target per batch
-        quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
+        quant_level_range = self.quant_level_range

-        if cfg.experimental:
+        if self.p_rvq_levels == "equal":
+            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+        else: # if self.p_rvq_levels == "auto":
             # makes higher levels less likely
             def generate( lo=0, hi=8 ):
                 index = lo
@@ -167,18 +174,16 @@ class NAR(Base):
                     index = i
                 return int(index)

-            quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
-        else:
-            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-            quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+            quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]

         # clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
-        for i, prom in enumerate(proms_list):
-            if quant_levels[i] + 1 > prom.shape[-1]:
-                quant_levels[i] = prom.shape[-1] - 1
-        for i, resp in enumerate(resps_list):
-            if quant_levels[i] + 1 > resp.shape[-1]:
-                quant_levels[i] = resp.shape[-1] - 1
+        for i in range(batch_size):
+            # cap quant_level if it exceeds its corresponding resp/prom
+            if quant_levels[i] >= resps_list[i].shape[-1]:
+                quant_levels[i] = resps_list[i].shape[-1] - 1
+
+            if quant_levels[i] >= proms_list[i].shape[-1]:
+                quant_levels[i] = proms_list[i].shape[-1] - 1

         resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

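To summarize the sampling behaviour these hunks set up: with `p_rvq_levels` set to `equal`, every level in `quant_level_range` is drawn uniformly, while `auto` biases toward lower levels (the exact decay inside `generate()` is not fully visible in the hunk). A minimal sketch of the `equal` path plus the new per-item cap, with made-up stand-in shapes:

```
import random

# Assumed inputs for illustration; real code derives these from the batch tensors.
quant_level_range = [0, 8]          # [lowest selectable level, number of RVQ levels]
batch_size = 4
resp_levels = [8, 8, 9, 9]          # stand-ins for resps_list[i].shape[-1]
prom_levels = [9, 8, 9, 9]          # stand-ins for proms_list[i].shape[-1]

# "equal": each level in the range is equally likely
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for _ in range(batch_size) ]

# cap each level so it never exceeds what the corresponding resp/prom actually has
# (equivalent to the pair of if-checks in the hunk above)
for i in range(batch_size):
    quant_levels[i] = min(quant_levels[i], resp_levels[i] - 1, prom_levels[i] - 1)
```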
@@ -31,8 +31,8 @@ def train_feeder(engine, batch):
     batch_size = len(batch["text"])
     engine.current_batch_size = batch_size

-    if engine.hyper_config.experimental:
-        if cfg.model.interleave:
+    if engine.hyper_config.experimental.hf:
+        if engine.hyper_config.experimental.interleave:
             quant_levels = 0
             resps_list = [ resp for resp in batch["resps"] ]
         else:
@@ -129,8 +129,8 @@ def run_eval(engines, eval_name, dl):
     for name in engines:
         engine = engines[name]

-        if engine.hyper_config.experimental:
-            if cfg.model.interleave:
+        if engine.hyper_config.experimental.hf:
+            if engine.hyper_config.experimental.interleave:
                 input_ids, attention_mask = fold_inputs(
                     text_list=batch["text"],
                     prom_list=batch["proms"],
@@ -78,7 +78,6 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--language", type=str, default="en")
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
     parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
-    parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
     parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
     parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
     parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@@ -211,7 +210,6 @@ with ui:
                 layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
                 layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                 layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
-                layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
             with gr.Row():
                 layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
                 layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")